diff --git a/.github/scripts/domain-check.js b/.github/scripts/domain-check.js new file mode 100644 index 000000000..ed68c2e42 --- /dev/null +++ b/.github/scripts/domain-check.js @@ -0,0 +1,156 @@ +// +// This script is used by pr.yml to determine if a domain must be tested based +// on the files modified in the pull request. +// + +// Given a domain name and set of files, return true if the domain should be +// tested +function matchesPattern(domain, filePaths) { + // filter files that end in .md + filePaths = filePaths.filter( + (filePath) => + !filePath.endsWith(".md") && + !filePath.startsWith("docs/") && + !filePath.startsWith("third-party-programs/"), + ); + // These directories contain domain specific code + const dirs = "(tests/unit_tests|examples|src|include/oneapi/mkl)"; + const domains = "(blas|lapack|rng|dft)"; + // matches changes to the domain of interest or non domain-specific code + const re = new RegExp(`^(${dirs}/${domain}|(?!${dirs}/${domains}))`); + const match = filePaths.some((filePath) => re.test(filePath)); + return match; +} + +// Return the list of files modified in the pull request +async function prFiles(github, context) { + let allFiles = []; + let page = 0; + let filesPerPage = 100; // GitHub's maximum per page + + while (true) { + page++; + const response = await github.rest.pulls.listFiles({ + owner: context.repo.owner, + repo: context.repo.repo, + pull_number: context.payload.pull_request.number, + per_page: filesPerPage, + page: page, + }); + + if (response.data.length === 0) { + break; // Exit the loop if no more files are returned + } + + allFiles = allFiles.concat(response.data.map((file) => file.filename)); + + if (response.data.length < filesPerPage) { + break; // Exit the loop if last page + } + } + + return allFiles; +} + +// Called by pr.yml. See: +// https://github.com/actions/github-script/blob/main/README.md for more +// information on the github and context parameters +module.exports = async ({ github, context, domain }) => { + if (!context.payload.pull_request) { + console.log("Not a pull request. Testing all domains."); + return true; + } + const files = await prFiles(github, context); + const match = matchesPattern(domain, files); + console.log("Domain: ", domain); + console.log("PR files: ", files); + console.log("Match: ", match); + return match; +}; + +// +// Test the matchesPattern function +// +// Run this script with `node domain-check.js` It should exit with code 0 if +// all tests pass. +// +// If you need to change the set of files that are ignored, add a test pattern +// below with positive and negative examples. It is also possible to test by +// setting up a fork and then submitting pull requests that modify files, but +// it requires a lot of manual work. +// +test_patterns = [ + { + domain: "blas", + files: ["tests/unit_tests/blas/test_blas.cpp"], + expected: true, + }, + { + domain: "rng", + files: ["examples/rng/example_rng.cpp"], + expected: true, + }, + { + domain: "lapack", + files: ["include/oneapi/mkl/lapack/lapack.hpp"], + expected: true, + }, + { + domain: "dft", + files: ["src/dft/lapack.hpp"], + expected: true, + }, + { + domain: "dft", + files: ["src/dft/lapack.md"], + expected: false, + }, + { + domain: "blas", + files: ["tests/unit_tests/dft/test_blas.cpp"], + expected: false, + }, + { + domain: "rng", + files: ["examples/blas/example_rng.cpp"], + expected: false, + }, + { + domain: "lapack", + files: ["include/oneapi/mkl/rng/lapack.hpp"], + expected: false, + }, + { + domain: "dft", + files: ["src/lapack/lapack.hpp"], + expected: false, + }, + { + domain: "dft", + files: ["docs/dft/dft.rst"], + expected: false, + }, + { + domain: "dft", + files: ["third-party-programs/dft/dft.rst"], + expected: false, + }, +]; + +function testPattern(test) { + const result = matchesPattern(test.domain, test.files); + if (result !== test.expected) { + console.log("Fail:"); + console.log(" domain:", test.domain); + console.log(" files:", test.files); + console.log(" expected:", test.expected); + console.log(" result:", result); + process.exit(1); + } +} + +if (require.main === module) { + // invoke test for each test pattern + test_patterns.forEach(testPattern); + console.log("All tests pass"); +} diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml new file mode 100644 index 000000000..000d737aa --- /dev/null +++ b/.github/workflows/documentation.yml @@ -0,0 +1,60 @@ +name: Documentation +permissions: read-all + +# Trigger for PR or merge to develop branch +on: + push: + branches: develop + paths: + - 'docs/**' + pull_request: + paths: + - 'docs/**' + workflow_dispatch: + +jobs: + build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@44c2b7a8a4ea60a981eaca3cf939b5f4305c123b # v4.1.5 + - uses: actions/setup-python@82c7e631bb3cdc910f68e0081d67478d79c6982d # v5.1.0 + with: + python-version: '3.11' + cache: 'pip' + - name: Install Dependencies + run: pip install -r docs/requirements.txt + - name: Configure & Build + run: | + cmake -DCMAKE_VERBOSE_MAKEFILE=on -B build docs + cmake --build build + - uses: actions/upload-artifact@65462800fd760344b1a7b4382951275a0abb4808 # v4.3.3 + with: + name: docs + path: build/Documentation/html + + publish: + needs: build + if: github.event_name == 'workflow_dispatch' || github.event_name == 'push' && github.ref == 'refs/heads/develop' + runs-on: ubuntu-latest + permissions: + contents: write + steps: + - uses: actions/checkout@44c2b7a8a4ea60a981eaca3cf939b5f4305c123b # v4.1.5 + with: + ref: gh-pages + path: gh-pages + - name: Remove old site + run: rm -rf gh-pages/* + - uses: actions/download-artifact@65a9edc5881444af0b9093a5e628f2fe47ea3b2e # v4.1.7 + with: + name: docs + path: gh-pages + - name: Push to GitHub Pages + run: | + cd gh-pages + touch .nojekyll + git add . + git config --global user.name "GitHub Actions" + git config --global user.email github-actions@github.com + git commit -m "Update documentation" + git push --force origin gh-pages diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml new file mode 100644 index 000000000..15a301299 --- /dev/null +++ b/.github/workflows/linux.yml @@ -0,0 +1,159 @@ +name: Build tests with Open SYCL + +on: [push, pull_request] + +jobs: + test-with-hipsycl: + name: Run tests with Open SYCL + strategy: + matrix: + clang_version: [13] + rocm_version: ['5.1.1'] + cuda_version: ['11.6'] + os: [ubuntu-20.04, ubuntu-18.04] + runs-on: ${{matrix.os}} + steps: + - uses: actions/checkout@v2 + with: + submodules: 'recursive' + - name: install ROCm + run: | + sudo apt install libnuma-dev cmake unzip + wget -q -O - https://repo.radeon.com/rocm/rocm.gpg.key | sudo apt-key add - + echo 'deb [arch=amd64] https://repo.radeon.com/rocm/apt/${{matrix.rocm_version}}/ ubuntu main' | sudo tee /etc/apt/sources.list.d/rocm.list + sudo apt update + sudo apt install rocm-dev rocblas + - name: install LLVM + run: | + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh ${{matrix.clang_version}} + sudo apt install libclang-${{matrix.clang_version}}-dev clang-tools-${{matrix.clang_version}} libomp-${{matrix.clang_version}}-dev + - name: install boost from apt + if: matrix.os == 'ubuntu-20.04' + run: | + sudo apt update + sudo apt install libboost-all-dev + - name: install boost from source + if: matrix.os == 'ubuntu-18.04' + run: | + cd + wget -q https://boostorg.jfrog.io/artifactory/main/release/1.70.0/source/boost_1_70_0.zip + unzip boost_1_70_0.zip + cd boost_1_70_0 + ./bootstrap.sh --prefix=/usr + ./b2 + sudo ./b2 install -j4 + - name: install CUDA + run: | + cd + wget https://developer.download.nvidia.com/compute/cuda/11.6.0/local_installers/cuda_11.6.0_510.39.01_linux.run + sudo sh cuda_11.6.0_510.39.01_linux.run --override --silent --toolkit + cd + - name: test CUDA and cuBLAS version and fix path by symlink + run: | + /usr/local/cuda/bin/nvcc --version + echo $(ls /usr/local/cuda-11.6/targets/x86_64-linux/lib/stubs/libcuda*) + echo $LD_LIBRARY_PATH + sudo ln -s /usr/local/cuda-11.6/targets/x86_64-linux/lib/stubs/libcuda.so /usr/local/cuda-11.6/targets/x86_64-linux/lib/stubs/libcuda.so.1 + - name: build Open SYCL + run: | + cd + git clone https://github.com/OpenSYCL/OpenSYCL.git + cd OpenSYCL + mkdir build && cd build + cmake -DCMAKE_CXX_COMPILER=/usr/bin/clang++-${{matrix.clang_version}} \ + -DCLANG_EXECUTABLE_PATH=/usr/bin/clang++-${{matrix.clang_version}} \ + -DLLVM_DIR=/usr/lib/llvm-${{matrix.clang_version}}/cmake \ + -DCUDA_CUDA_LIBRARY=/usr/local/cuda-${{matrix.cuda_version}}/lib64/stubs/libcuda.so \ + -DWITH_CUDA_BACKEND=ON \ + -DWITH_ROCM_BACKEND=OFF \ + -DWITH_LEVEL_ZERO_BACKEND=OFF \ + -DCMAKE_INSTALL_PREFIX=/opt/OpenSYCL \ + -DROCM_PATH=/opt/rocm-${{matrix.rocm_version}} \ + -DCUDA_PATH=/usr/local/cuda-${{matrix.cuda_version}} .. + make -j2 install + - name: Build Open SYCL CPU Tests + run: | + cd $HOME/OpenSYCL/build + mkdir tests-omp && cd tests-omp + cmake \ + -DOPENSYCL_TARGETS="omp" \ + -DOpenSYCL_DIR=/opt/OpenSYCL/lib/cmake/OpenSYCL ../../tests + make -j 2 VERBOSE=ON + - name: Open SYCL tests for CPU + run: | + cd $HOME/OpenSYCL/build/tests-omp + LD_LIBRARY_PATH=$HOME/OpenSYCL/build/install/lib ./sycl_tests + - name: install LAPACK (for CBLAS) + run: | + cd + git clone https://github.com/Reference-LAPACK/lapack.git + cd lapack + mkdir build && cd build + cmake -DCMAKE_INSTALL_LIBDIR=/opt/lapack -DBUILD_SHARED_LIBS=ON -DCBLAS=ON .. + sudo cmake --build . -j --target install + - name: clone and build oneMKL with cuBLAS backend + env: + rocblas_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/rocblas/ + hip_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/hip/ + AMDDeviceLibs_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/AMDDeviceLibs/ + amd_comgr_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/amd_comgr/ + hsa-runtime64_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/hsa-runtime64/ + run: | + cd + git clone https://github.com/oneapi-src/oneMKL.git + cd oneMKL + mkdir build + cd build + cmake \ + -DCMAKE_CXX_COMPILER=/usr/bin/clang++-${{matrix.clang_version}} \ + -DCMAKE_C_COMPILER=/usr/bin/clang-${{matrix.clang_version}} \ + -DCUDA_CUDA_LIBRARY=/usr/local/cuda-${{matrix.cuda_version}}/lib64/stubs/libcuda.so \ + -DENABLE_CUBLAS_BACKEND=True \ + -DENABLE_CURAND_BACKEND=False \ + -DENABLE_CUSOLVER_BACKEND=False \ + -DENABLE_MKLGPU_BACKEND=False \ + -DENABLE_MKLCPU_BACKEND=False \ + -DENABLE_NETLIB_BACKEND=False \ + -DENABLE_ROCBLAS_BACKEND=False \ + -DBUILD_SHARED_LIBS=ON \ + -DTARGET_DOMAINS=blas \ + -DREF_BLAS_ROOT=/opt/lapack/ \ + -DhipSYCL_DIR=/opt/OpenSYCL/lib/cmake/hipSYCL \ + -DHIPSYCL_TARGETS="omp;cuda:sm_60" \ + -DONEMKL_SYCL_IMPLEMENTATION=hipsycl .. + cmake --build . -j2 + cmake --install . --prefix /opt/oneMKL + - name: clone and build oneMKL with rocBLAS backend + env: + rocblas_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/rocblas/ + hip_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/hip/ + AMDDeviceLibs_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/AMDDeviceLibs/ + amd_comgr_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/amd_comgr/ + hsa-runtime64_DIR: /opt/rocm-${{matrix.rocm_version}}/lib/cmake/hsa-runtime64/ + run: | + cd + git clone https://github.com/oneapi-src/oneMKL.git oneMKL-rocBLAS + cd oneMKL-rocBLAS + mkdir build + cd build + cmake \ + -DCMAKE_CXX_COMPILER=/usr/bin/clang++-${{matrix.clang_version}} \ + -DCMAKE_C_COMPILER=/usr/bin/clang-${{matrix.clang_version}} \ + -DROCBLAS_LIBRARIES=/opt/rocm-${{matrix.rocm_version}}/lib/cmake/rocblas/ \ + -DENABLE_CUBLAS_BACKEND=False \ + -DENABLE_CURAND_BACKEND=False \ + -DENABLE_CUSOLVER_BACKEND=False \ + -DENABLE_MKLGPU_BACKEND=False \ + -DENABLE_MKLCPU_BACKEND=False \ + -DENABLE_NETLIB_BACKEND=False \ + -DENABLE_ROCBLAS_BACKEND=True \ + -DBUILD_SHARED_LIBS=ON \ + -DTARGET_DOMAINS=blas \ + -DREF_BLAS_ROOT=/opt/lapack/ \ + -DhipSYCL_DIR=/opt/OpenSYCL/lib/cmake/hipSYCL \ + -DHIPSYCL_TARGETS="omp;hip:gfx906" \ + -DONEMKL_SYCL_IMPLEMENTATION=hipsycl .. + cmake --build . -j2 + cmake --install . --prefix /opt/oneMKL-rocBLAS diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 000000000..0da2153c6 --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,84 @@ +name: "PR Tests" +permissions: read-all + +# Trigger for PR and merge to develop branch +on: + push: + branches: develop + pull_request: + workflow_dispatch: + +env: + CTEST_OUTPUT_ON_FAILURE: 1 + LAPACK_VERSION: 3.12.0 + PARALLEL: -j 2 + +jobs: + unit-tests: + runs-on: ubuntu-latest + # One runner for each domain + strategy: + matrix: + include: + - config: portBLAS + domain: blas + build_options: -DREF_BLAS_ROOT=${PWD}/lapack/install -DENABLE_PORTBLAS_BACKEND=ON -DENABLE_MKLCPU_BACKEND=OFF -DPORTBLAS_TUNING_TARGET=INTEL_CPU + - config: portFFT + domain: dft + build_options: -DENABLE_PORTFFT_BACKEND=ON -DENABLE_MKLCPU_BACKEND=OFF + test_options: -R 'DFT/CT/.*ComputeTests_in_place_COMPLEX.COMPLEX_SINGLE_in_place_buffer.sizes_8_batches_1*' + - config: MKL BLAS + domain: blas + build_options: -DREF_BLAS_ROOT=${PWD}/lapack/install + - config: MKL DFT + domain: dft + - config: MKL LAPACK + domain: lapack + build_options: -DREF_LAPACK_ROOT=${PWD}/lapack/install + - config: MKL RNG + domain: rng + name: unit tests ${{ matrix.config }} CPU + steps: + - uses: actions/checkout@44c2b7a8a4ea60a981eaca3cf939b5f4305c123b # v4.1.5 + - name: Check if the changes affect this domain + id: domain_check + uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1 + with: + script: | + const domainCheck = require('.github/scripts/domain-check.js') + return domainCheck({github, context, domain: "${{ matrix.domain }}"}) + - name: Restore netlib from cache + id: cache-lapack + uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # v4.0.2 + with: + path: lapack/install + key: lapack-${{ env.LAPACK_VERSION }} + - name: Install netlib + if: steps.domain_check.outputs.result == 'true' && steps.cache-lapack.outputs.cache-hit != 'true' + run: | + curl -sL https://github.com/Reference-LAPACK/lapack/archive/refs/tags/v${LAPACK_VERSION}.tar.gz | tar zx + SHARED_OPT="lapack-${LAPACK_VERSION} -DBUILD_SHARED_LIBS=on -DCBLAS=on -DLAPACKE=on -DCMAKE_INSTALL_PREFIX=${PWD}/lapack/install" + # 32 bit int + cmake ${SHARED_OPT} -B lapack/build32 + cmake --build lapack/build32 ${PARALLEL} --target install + # 64 bit int + cmake ${SHARED_OPT} -DBUILD_INDEX64=on -B lapack/build64 + cmake --build lapack/build64 ${PARALLEL} --target install + - name: Install oneapi + if: steps.domain_check.outputs.result == 'true' + uses: rscohn2/setup-oneapi@2ad0cf6b74bc2426bdcee825cf88f9db719dd727 # v0.1.0 + with: + components: | + icx@2024.1.0 + mkl@2024.1.0 + - name: Configure/Build for a domain + if: steps.domain_check.outputs.result == 'true' + run: | + source /opt/intel/oneapi/setvars.sh + cmake -DTARGET_DOMAINS=${{ matrix.domain }} -DENABLE_MKLGPU_BACKEND=off -DCMAKE_VERBOSE_MAKEFILE=on ${{ matrix.build_options }} -B build + cmake --build build ${PARALLEL} + - name: Run tests + if: steps.domain_check.outputs.result == 'true' + run: | + source /opt/intel/oneapi/setvars.sh + ctest --test-dir build ${{ matrix.test_options }} diff --git a/.github/workflows/slack-pr.yaml b/.github/workflows/slack-pr.yaml new file mode 100644 index 000000000..4c5f3df7d --- /dev/null +++ b/.github/workflows/slack-pr.yaml @@ -0,0 +1,43 @@ +#=============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#=============================================================================== + +name: Slack PR Notification +on: + # use pull_request_target to run on PRs from forks and have access to secrets + pull_request_target: + types: [labeled] + +env: + SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} + channel: "onemkl" + +permissions: + pull-requests: read + +jobs: + rfc: + name: RFC Notification + runs-on: ubuntu-latest + # Trigger when labeling a PR with "RFC" + if: | + github.event.action == 'labeled' && + contains(toJson(github.event.pull_request.labels.*.name), '"RFC"') + steps: + - name: Notify Slack + uses: slackapi/slack-github-action@70cd7be8e40a46e8b0eced40b0de447bdb42f68e # v1.26.0 + with: + channel-id: ${{ env.channel }} + slack-message: "${{ github.actor }} posted a RFC: ${{ github.event.pull_request.title }}. URL: ${{ github.event.pull_request.html_url }}" diff --git a/.gitignore b/.gitignore index 2a3a157c9..631826c77 100644 --- a/.gitignore +++ b/.gitignore @@ -16,4 +16,4 @@ .git/ # Build -build/ +build*/ diff --git a/CMakeLists.txt b/CMakeLists.txt index 6f961d77b..79af06f6a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -20,12 +20,6 @@ cmake_minimum_required (VERSION 3.13) -# Set flags from Conan if Conan is used -if(EXISTS "${CMAKE_BINARY_DIR}/conanbuildinfo.cmake") - include("${CMAKE_BINARY_DIR}/conanbuildinfo.cmake" REQUIRED) - conan_basic_setup() -endif() - # Define build type set(DEFAULT_BUILD_TYPE "Release") @@ -41,20 +35,31 @@ endif() option(BUILD_SHARED_LIBS "Build dynamic libraries" ON) ## Backends -option(ENABLE_MKLCPU_BACKEND "" ON) -option(ENABLE_MKLGPU_BACKEND "" ON) +option(ENABLE_MKLCPU_BACKEND "Enable the Intel oneMKL CPU backend for supported interfaces" ON) +option(ENABLE_MKLGPU_BACKEND "Enable the Intel oneMKL GPU backend for supported interfaces" ON) if(ENABLE_MKLCPU_BACKEND) - option(ENABLE_MKLCPU_THREAD_TBB "" ON) + option(ENABLE_MKLCPU_THREAD_TBB "Enable the use of Intel TBB with the oneMKL CPU backend" ON) endif() -option(ENABLE_CUBLAS_BACKEND "" OFF) -option(ENABLE_CUSOLVER_BACKEND "" OFF) +# blas +option(ENABLE_CUBLAS_BACKEND "Enable the cuBLAS backend for the BLAS interface" OFF) +option(ENABLE_ROCBLAS_BACKEND "Enable the rocBLAS backend for the BLAS interface" OFF) +option(ENABLE_NETLIB_BACKEND "Enable the Netlib backend for the BLAS interface" OFF) +option(ENABLE_PORTBLAS_BACKEND "Enable the portBLAS backend for the BLAS interface. Cannot be used with other BLAS backends." OFF) + +# rand +option(ENABLE_CURAND_BACKEND "Enable the cuRAND backend for the RNG interface" OFF) +option(ENABLE_ROCRAND_BACKEND "Enable the rocRAND backend for the RNG interface" OFF) + +# lapack +option(ENABLE_CUSOLVER_BACKEND "Enable the cuSOLVER backend for the LAPACK interface" OFF) +option(ENABLE_ROCSOLVER_BACKEND "Enable the rocSOLVER backend for the LAPACK interface" OFF) + +# dft +option(ENABLE_CUFFT_BACKEND "Enable the cuFFT backend for the DFT interface" OFF) +option(ENABLE_ROCFFT_BACKEND "Enable the rocFFT backend for the DFT interface" OFF) +option(ENABLE_PORTFFT_BACKEND "Enable the portFFT DFT backend for the DFT interface. Cannot be used with other DFT backends." OFF) -option(ENABLE_ROCBLAS_BACKEND "" OFF) -option(ENABLE_CURAND_BACKEND "" OFF) -option(ENABLE_ROCRAND_BACKEND "" OFF) -option(ENABLE_ROCSOLVER_BACKEND "" OFF) -option(ENABLE_NETLIB_BACKEND "" OFF) set(ONEMKL_SYCL_IMPLEMENTATION "dpc++" CACHE STRING "Name of the SYCL compiler") set(HIP_TARGETS "" CACHE STRING "Target HIP architectures") @@ -73,7 +78,8 @@ if(ENABLE_MKLCPU_BACKEND OR ENABLE_MKLGPU_BACKEND OR ENABLE_CUBLAS_BACKEND OR ENABLE_ROCBLAS_BACKEND - OR ENABLE_NETLIB_BACKEND) + OR ENABLE_NETLIB_BACKEND + OR ENABLE_PORTBLAS_BACKEND) list(APPEND DOMAINS_LIST "blas") endif() if(ENABLE_MKLCPU_BACKEND @@ -89,9 +95,33 @@ if(ENABLE_MKLCPU_BACKEND list(APPEND DOMAINS_LIST "rng") endif() if(ENABLE_MKLGPU_BACKEND - OR ENABLE_MKLCPU_BACKEND) + OR ENABLE_MKLCPU_BACKEND + OR ENABLE_CUFFT_BACKEND + OR ENABLE_ROCFFT_BACKEND + OR ENABLE_PORTFFT_BACKEND) list(APPEND DOMAINS_LIST "dft") endif() +if(ENABLE_MKLCPU_BACKEND + OR ENABLE_MKLGPU_BACKEND) + list(APPEND DOMAINS_LIST "sparse_blas") +endif() + +if(ENABLE_PORTBLAS_BACKEND + AND (ENABLE_MKLCPU_BACKEND + OR ENABLE_MKLGPU_BACKEND + OR ENABLE_CUBLAS_BACKEND + OR ENABLE_ROCBLAS_BACKEND + OR ENABLE_NETLIB_BACKEND)) + message(FATAL_ERROR "ENABLE_PORTBLAS_BACKEND cannot be enabled at the same time as other BLAS backends.") +endif() + +if (ENABLE_PORTFFT_BACKEND + AND (ENABLE_MKLCPU_BACKEND + OR ENABLE_MKLGPU_BACKEND + OR ENABLE_ROCFFT_BACKEND + OR ENABLE_CUFFT_BACKEND)) + message(FATAL_ERROR "ENABLE_PORTFFT_BACKEND cannot be enabled at the same time as other DFT backends.") +endif() # Define required CXX compilers before project if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") @@ -99,8 +129,8 @@ if(CMAKE_CXX_COMPILER OR NOT ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") string(REPLACE "\\" "/" CMAKE_CXX_COMPILER ${CMAKE_CXX_COMPILER}) endif() else() - if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_ROCBLAS_BACKEND - OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND) + if(ENABLE_CUBLAS_BACKEND OR ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND OR ENABLE_CUFFT_BACKEND + OR ENABLE_ROCBLAS_BACKEND OR ENABLE_ROCRAND_BACKEND OR ENABLE_ROCSOLVER_BACKEND OR ENABLE_ROCFFT_BACKEND) set(CMAKE_CXX_COMPILER "clang++") elseif(ENABLE_MKLGPU_BACKEND) if(UNIX) @@ -167,18 +197,14 @@ if(WIN32 AND ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++") endforeach() set(CMAKE_CXX_COMPILE_OBJECT " -fsycl /nologo /EHsc /Fo -c ") set(CMAKE_CXX_CREATE_STATIC_LIBRARY "lib /nologo /out:") - set(CMAKE_CXX_LINK_EXECUTABLE " -fsycl -fsycl-device-code-split=per_kernel /nologo -o ") - set(MKL_SYCL_LIB "") - if(ENABLE_MKLGPU_BACKEND OR ENABLE_MKLCPU_BACKEND) - find_library(MKL_SYCL_LIB NAMES mkl_sycl - HINTS $ENV{MKLROOT} ${MKL_ROOT} - PATH_SUFFIXES lib/intel64) + if(CMAKE_VERSION VERSION_LESS "3.25.2") + set(CMAKE_CXX_LINK_EXECUTABLE " -fsycl -fsycl-device-code-split=per_kernel /nologo -o ") + set(CMAKE_CXX_CREATE_SHARED_LIBRARY " -fsycl -fsycl-device-code-split=per_kernel /nologo /link /out: /implib: /pdb: /dll /version:.") endif() - set(CMAKE_CXX_CREATE_SHARED_LIBRARY " -fsycl -fsycl-device-code-split=per_kernel /nologo ${MKL_SYCL_LIB} /link /out: /implib: /pdb: /dll /version:. ") endif() # Temporary disable sycl 2020 deprecations warnings for cuSOLVER and rocSOLVER -if(ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++" AND (ENABLE_CUSOLVER_BACKEND OR ENABLE_ROCSOLVER_BACKEND)) +if(ONEMKL_SYCL_IMPLEMENTATION STREQUAL "dpc++" AND (ENABLE_ROCSOLVER_BACKEND)) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DSYCL2020_DISABLE_DEPRECATION_WARNINGS") endif() @@ -203,6 +229,32 @@ else() endif() message(STATUS "TARGET_DOMAINS: ${TARGET_DOMAINS}") +# Include Intel oneMKL +if(ENABLE_MKLGPU_BACKEND OR ENABLE_MKLCPU_BACKEND) + set(MKL_ARCH intel64) + set(MKL_INTERFACE ilp64) + if(ENABLE_MKLCPU_THREAD_TBB) + set(MKL_THREADING tbb_thread) + else() + set(MKL_THREADING sequential) + endif() + if(BUILD_SHARED_LIBS AND NOT WIN32) + set(MKL_LINK dynamic) + else() + set(MKL_LINK static) + endif() + # Enable SYCL API + set(DPCPP_COMPILER ON) + set(SYCL_COMPILER ON) + # In case Intel oneMKL package doesn't include MKLConfig, + # use MKLConfig from the repo + find_package(MKL REQUIRED + HINTS ${MKL_ROOT}/lib/cmake + ${MKL_ROOT}/lib/cmake/mkl + $ENV{MKLROOT} + ${PROJECT_SOURCE_DIR}/cmake/mkl) +endif() + # Set output directories for the project set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin) set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index deb61129b..9a41383bd 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -320,6 +320,6 @@ for (int i = 0; i < loop_size; i++) ...; ## Unit Tests -oneMKL uses GoogleTest for functional testing. +oneMKL uses GoogleTest for functional testing. For more information about how to build and run Unit Tests please see [Building and Running Tests](https://oneapi-src.github.io/oneMKL/building_and_running_tests.html). Be sure to extend the existing tests when fixing an issue, adding a new interface or new implementation under existing interfaces. diff --git a/README.md b/README.md index cdfce13e6..e74e3b5ed 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ -# oneAPI Math Kernel Library (oneMKL) Interfaces +UXL Foundation Logo -oneAPI logo +# oneAPI Math Kernel Library (oneMKL) Interfaces -oneMKL Interfaces is an open-source implementation of the oneMKL Data Parallel C++ (DPC++) interface according to the [oneMKL specification](https://spec.oneapi.com/versions/latest/elements/oneMKL/source/index.html). It works with multiple devices (backends) using device-specific libraries underneath. +oneMKL Interfaces is an open-source implementation of the oneMKL Data Parallel C++ (DPC++) interface according to the [oneMKL specification](https://oneapi-spec.uxlfoundation.org/specifications/oneapi/latest/elements/onemkl/source/). It works with multiple devices (backends) using device-specific libraries underneath. -oneMKL is part of [oneAPI](https://oneapi.io). +oneMKL is part of the [UXL Foundation](http://www.uxlfoundation.org).

@@ -18,43 +18,56 @@ oneMKL is part of [oneAPI](https://oneapi.io). - - - - + + + + - - - - + - + - + + + + + - + - + + + + + - + - + + + + + + + + +
oneMKL interfaceoneMKL selectorIntel(R) oneAPI Math Kernel Library for x86 CPUx86 CPUoneMKL interfaceoneMKL selectorIntel(R) oneAPI Math Kernel Library (oneMKL)x86 CPU, Intel GPU
Intel(R) oneAPI Math Kernel Library for Intel GPUIntel GPU
NVIDIA cuBLAS for NVIDIA GPU NVIDIA cuBLAS NVIDIA GPU
NVIDIA cuSOLVER for NVIDIA GPU NVIDIA cuSOLVER NVIDIA GPU
NVIDIA cuRAND for NVIDIA GPU NVIDIA cuRANDNVIDIA GPU
NVIDIA cuFFT NVIDIA GPU
NETLIB LAPACK for x86 CPU NETLIB LAPACK x86 CPU
AMD rocBLAS for AMD GPU AMD rocBLASAMD GPU
AMD rocSOLVER AMD GPU
AMD rocSOLVER for AMD GPU AMD rocRAND AMD GPU
AMD rocRAND for AMD GPU AMD rocFFT AMD GPU
portBLAS x86 CPU, Intel GPU, NVIDIA GPU, AMD GPU
portFFT x86 CPU, Intel GPU, NVIDIA GPU, AMD GPU
@@ -72,65 +85,78 @@ oneMKL is part of [oneAPI](https://oneapi.io). ### Supported Usage Models: +#### Host API + There are two oneMKL selector layer implementations: - **Run-time dispatching**: The application is linked with the oneMKL library and the required backend is loaded at run-time based on device vendor (all libraries should be dynamic). -Example of app.cpp with run-time dispatching: - -```cpp -#include "oneapi/mkl.hpp" - -... -cpu_dev = sycl::device(sycl::cpu_selector()); -gpu_dev = sycl::device(sycl::gpu_selector()); - -sycl::queue cpu_queue(cpu_dev); -sycl::queue gpu_queue(gpu_dev); - -oneapi::mkl::blas::column_major::gemm(cpu_queue, transA, transB, m, ...); -oneapi::mkl::blas::column_major::gemm(gpu_queue, transA, transB, m, ...); -``` -How to build an application with run-time dispatching: - -if OS is Linux, use icpx compiler. If OS is Windows, use icx compiler. -Linux example: -```cmd -$> icpx -fsycl –I$ONEMKL/include app.cpp -$> icpx -fsycl app.o –L$ONEMKL/lib –lonemkl -``` + Example of app.cpp with run-time dispatching: + + ```cpp + #include "oneapi/mkl.hpp" + + ... + cpu_dev = sycl::device(sycl::cpu_selector()); + gpu_dev = sycl::device(sycl::gpu_selector()); + + sycl::queue cpu_queue(cpu_dev); + sycl::queue gpu_queue(gpu_dev); + + oneapi::mkl::blas::column_major::gemm(cpu_queue, transA, transB, m, ...); + oneapi::mkl::blas::column_major::gemm(gpu_queue, transA, transB, m, ...); + ``` + How to build an application with run-time dispatching: + + if OS is Linux, use icpx compiler. If OS is Windows, use icx compiler. + Linux example: + ```cmd + $> icpx -fsycl –I$ONEMKL/include app.cpp + $> icpx -fsycl app.o –L$ONEMKL/lib –lonemkl + ``` - **Compile-time dispatching**: The application uses a templated backend selector API where the template parameters specify the required backends and third-party libraries and the application is linked with the required oneMKL backend wrapper libraries (libraries can be static or dynamic). -Example of app.cpp with compile-time dispatching: - -```cpp -#include "oneapi/mkl.hpp" - -... -cpu_dev = sycl::device(sycl::cpu_selector()); -gpu_dev = sycl::device(sycl::gpu_selector()); - -sycl::queue cpu_queue(cpu_dev); -sycl::queue gpu_queue(gpu_dev); - -oneapi::mkl::backend_selector cpu_selector(cpu_queue); + Example of app.cpp with compile-time dispatching: + + ```cpp + #include "oneapi/mkl.hpp" + + ... + cpu_dev = sycl::device(sycl::cpu_selector()); + gpu_dev = sycl::device(sycl::gpu_selector()); + + sycl::queue cpu_queue(cpu_dev); + sycl::queue gpu_queue(gpu_dev); + + oneapi::mkl::backend_selector cpu_selector(cpu_queue); + + oneapi::mkl::blas::column_major::gemm(cpu_selector, transA, transB, m, ...); + oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector {gpu_queue}, transA, transB, m, ...); + ``` + How to build an application with compile-time dispatching: + + ```cmd + $> clang++ -fsycl –I$ONEMKL/include app.cpp + $> clang++ -fsycl app.o –L$ONEMKL/lib –lonemkl_blas_mklcpu –lonemkl_blas_cublas + ``` + +*Refer to [Selecting a Compiler](https://oneapi-src.github.io/oneMKL/selecting_a_compiler.html) for the choice between `icpx/icx` and `clang++` compilers.* -oneapi::mkl::blas::column_major::gemm(cpu_selector, transA, transB, m, ...); -oneapi::mkl::blas::column_major::gemm(oneapi::mkl::backend_selector {gpu_queue}, transA, transB, m, ...); -``` -How to build an application with compile-time dispatching: +#### Device API -```cmd -$> clang++ -fsycl –I$ONEMKL/include app.cpp -$> clang++ -fsycl app.o –L$ONEMKL/lib –lonemkl_blas_mklcpu –lonemkl_blas_cublas -``` +Header-based and backend-independent Device API can be called within ```sycl kernel``` or work from Host code ([device-rng-usage-model-example](https://spec.oneapi.io/versions/latest/elements/oneMKL/source/domains/rng/device_api/device-rng-usage-model.html#id2)). Currently, the following domains support the Device API: -*Refer to [Selecting a Compiler](https://oneapi-src.github.io/oneMKL/selecting_a_compiler.html) for the choice between `icpx/icx` and `clang++` compilers.* +- **RNG**. To use RNG Device API functionality it's required to include ```oneapi/mkl/rng/device.hpp``` header file. ### Supported Configurations: -Supported domains: BLAS, LAPACK, RNG +Supported domains include: BLAS, LAPACK, RNG, DFT, SPARSE_BLAS + +Supported compilers include: +- [Intel(R) oneAPI DPC++ Compiler](https://software.intel.com/en-us/oneapi/dpc-compiler): Intel proprietary compiler that supports CPUs and Intel GPUs. Intel(R) oneAPI DPC++ Compiler will be referred to as "Intel DPC++" in the "Supported Compiler" column of the tables below. +- [oneAPI DPC++ Compiler](https://github.com/intel/llvm): Open source compiler that supports CPUs and Intel, NVIDIA, and AMD GPUs. oneAPI DPC++ Compiler will be referred to as "Open DPC++" in the "Supported Compiler" column of the tables below. +- [AdaptiveCpp Compiler](https://github.com/AdaptiveCpp/AdaptiveCpp) (formerly known as hipSYCL): Open source compiler that supports CPUs and Intel, NVIDIA, and AMD GPUs.
**Note**: The source code and some documents in this project still use the previous name hipSYCL during this transition period. #### Linux* @@ -140,95 +166,168 @@ Supported domains: BLAS, LAPACK, RNG Domain Backend Library - Supported Link Type Supported Compiler + Supported Link Type - BLAS - x86 CPU - Intel(R) oneAPI Math Kernel Library + BLAS + x86 CPU + Intel(R) oneMKL + Intel DPC++
AdaptiveCpp Dynamic, Static - DPC++, LLVM*, hipSYCL - Intel GPU + NETLIB LAPACK + Intel DPC++
Open DPC++
AdaptiveCpp Dynamic, Static - DPC++ - NVIDIA GPU + portBLAS + Intel DPC++
Open DPC++ + Dynamic, Static + + + Intel GPU + Intel(R) oneMKL + Intel DPC++ + Dynamic, Static + + + portBLAS + Intel DPC++
Open DPC++ + Dynamic, Static + + + NVIDIA GPU NVIDIA cuBLAS + Open DPC++
AdaptiveCpp Dynamic, Static - LLVM*, hipSYCL - x86 CPU - NETLIB LAPACK + portBLAS + Open DPC++ Dynamic, Static - DPC++, LLVM*, hipSYCL - - AMD GPU + + AMD GPU AMD rocBLAS + Open DPC++
AdaptiveCpp + Dynamic, Static + + + portBLAS + Open DPC++ Dynamic, Static - LLVM*, hipSYCL LAPACK x86 CPU - Intel(R) oneAPI Math Kernel Library + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++, LLVM* Intel GPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++ NVIDIA GPU NVIDIA cuSOLVER + Open DPC++ Dynamic, Static - LLVM* AMD GPU AMD rocSOLVER + Open DPC++ Dynamic, Static - LLVM* RNG x86 CPU - Intel(R) oneAPI Math Kernel Library + Intel(R) oneMKL + Intel DPC++
AdaptiveCpp Dynamic, Static - DPC++, LLVM*, hipSYCL Intel GPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++ NVIDIA GPU NVIDIA cuRAND + Open DPC++
AdaptiveCpp Dynamic, Static - LLVM*, hipSYCL AMD GPU AMD rocRAND + Open DPC++
AdaptiveCpp + Dynamic, Static + + + DFT + x86 CPU + Intel(R) oneMKL + Intel DPC++ + Dynamic, Static + + + portFFT (limited API support) + Intel DPC++ + Dynamic, Static + + + Intel GPU + Intel(R) oneMKL + Intel DPC++ + Dynamic, Static + + + portFFT (limited API support) + Intel DPC++ + Dynamic, Static + + + NVIDIA GPU + NVIDIA cuFFT + Open DPC++ + Dynamic, Static + + + portFFT (limited API support) + Open DPC++ + Dynamic, Static + + + AMD GPU + AMD rocFFT + Open DPC++ + Dynamic, Static + + + portFFT (limited API support) + Open DPC++ + Dynamic, Static + + + SPARSE_BLAS + x86 CPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - LLVM*, hipSYCL - DFT Intel GPU - Intel(R) oneAPI Math Kernel Library + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++ @@ -241,58 +340,58 @@ Supported domains: BLAS, LAPACK, RNG Domain Backend Library - Supported Link Type Supported Compiler + Supported Link Type BLAS - x86 CPU - Intel(R) oneAPI Math Kernel Library + x86 CPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++, LLVM* - Intel GPU + NETLIB LAPACK + Intel DPC++
Open DPC++ Dynamic, Static - DPC++ - x86 CPU - NETLIB LAPACK + Intel GPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++, LLVM* LAPACK x86 CPU - Intel(R) oneAPI Math Kernel Library + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++, LLVM* Intel GPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++ RNG x86 CPU - Intel(R) oneAPI Math Kernel Library + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++, LLVM* Intel GPU + Intel(R) oneMKL + Intel DPC++ Dynamic, Static - DPC++ -\* LLVM - [Intel project for LLVM* technology](https://github.com/intel/llvm) with support for [NVIDIA CUDA](https://intel.github.io/llvm-docs/GetStartedGuide.html#build-dpc-toolchain-with-support-for-nvidia-cuda) - --- ### Hardware Platform Support @@ -302,8 +401,9 @@ Supported domains: BLAS, LAPACK, RNG - Intel(R) Core(TM) Processor Family - Intel(R) Xeon(R) Processor Family - Accelerators - - Intel(R) Processor Graphics GEN9 - - NVIDIA(R) TITAN RTX(TM) (Linux* only. cuRAND backend tested also with Quadro and A100 GPUs. Not tested with other NVIDIA GPU families and products.) + - Intel(R) Arc(TM) A-Series Graphics + - Intel(R) Data Center GPU Max Series + - NVIDIA(R) A100 (Linux* only) - AMD(R) GPUs see [here](https://github.com/RadeonOpenCompute/ROCm#hardware-and-software-support) tested on AMD Vega 20 (gfx906) --- @@ -311,19 +411,18 @@ Supported domains: BLAS, LAPACK, RNG #### Linux* -Operating System | CPU Host/Target | Integrated Graphics from Intel (Intel GPU) | NVIDIA GPU -:--- | :--- | :--- | :--- -Ubuntu | 18.04.3, 19.04 | 18.04.3, 19.10 | 18.04.3, 20.04 -SUSE Linux Enterprise Server* | 15 | *Not supported* | *Not supported* -Red Hat Enterprise Linux* (RHEL*) | 8 | *Not supported* | *Not supported* -Linux* kernel | *N/A* | 4.11 or higher | *N/A* +Backend | Supported Operating System +:--- | :--- +x86 CPU | Red Hat Enterprise Linux* 9 (RHEL* 9) +Intel GPU | Ubuntu 22.04 LTS +NVIDIA GPU | Ubuntu 22.04 LTS #### Windows* -Operating System | CPU Host/Target | Integrated Graphics from Intel (Intel GPU) -:--- | :--- | :--- -Microsoft Windows* | 10 (64-bit version only) | 10 (64-bit version only) -Microsoft Windows* Server | 2016, 2019 | *Not supported* +Backend | Supported Operating System +:--- | :--- +x86 CPU | Microsoft Windows* Server 2022 +Intel GPU | Microsoft Windows* 11 --- ### Software Requirements @@ -334,11 +433,6 @@ Microsoft Windows* Server | 2016, 2019 | *Not supported* - - - - - @@ -346,16 +440,16 @@ Microsoft Windows* Server | 2016, 2019 | *Not supported* - + + + + - - - @@ -376,75 +470,53 @@ Microsoft Windows* Server | 2016, 2019 | *Not supported* - - - + - - - - - - + - - +
Using Conan Using CMake Directly
Functional Testing Build Only Documentation
Linux* : GNU* GCC 5.1 or higher
Windows* : MSVS* 2017 or MSVS* 2019 (version 16.5 or newer)
CMake (version 3.13 or newer)
Linux* : GNU* GCC 5.1 or higher
Windows* : MSVS* 2017 or MSVS* 2019 (version 16.5 or newer)
Python 3.6 or higher CMake
Ninja (optional)
Conan C++ package manager GNU* FORTRAN Compiler - Sphinx Operating System Device PackageInstalled by Conan
Linux*/Windows* x86 CPU Intel(R) oneAPI DPC++ Compiler
or
Intel project for LLVM* technology
No Intel(R) oneAPI DPC++ Compiler
or
oneAPI DPC++ Compiler
Intel(R) oneAPI Math Kernel Library Yes
Intel GPU Intel(R) oneAPI DPC++ Compiler No
Intel GPU driver No
Intel(R) oneAPI Math Kernel Library Yes
Linux* only NVIDIA GPU Intel project for LLVM* technology
or
hipSYCL with CUDA backend and dependencies
No oneAPI DPC++ Compiler
or
AdaptiveCpp with CUDA backend and dependencies
AMD GPU Intel project for LLVM* technology
or
hipSYCL with ROCm backend and dependencies
No oneAPI DPC++ Compiler
or
AdaptiveCpp with ROCm backend and dependencies
-*If [Building with Conan](https://oneapi-src.github.io/oneMKL/building_the_project.html#building-with-conan), above packages marked as "No" must be installed manually.* - -*If [Building with CMake](https://oneapi-src.github.io/oneMKL/building_the_project.html#building-with-cmake), above packages must be installed manually.* - -#### Notice for Use of Conan Package Manager -**LEGAL NOTICE: By downloading and using this container or script as applicable (the "Software Package") and the included software or software made available for download, you agree to the terms and conditions of the software license agreements for the Software Package, which may also include notices, disclaimers, or license terms for third party software (together, the "Agreements") included in this README file.** - -**If the Software Package is installed through a silent install, your download and use of the -Software Package indicates your acceptance of the Agreements.** - #### Product and Version Information: -Product | Supported Version | Installed by Conan | Conan Package Source | Package Install Location on Linux* | License -:--- | :--- | :--- | :--- | :--- | :--- -Python | 3.6 or higher | No | *N/A* | *Pre-installed or Installed by user* | [PSF](https://docs.python.org/3.6/license.html) -[Conan C++ Package Manager](https://conan.io/downloads.html) | 1.24 or higher | No | *N/A* | *Installed by user* | [MIT](https://github.com/conan-io/conan/blob/develop/LICENSE.md) -[CMake](https://cmake.org/download/) | 3.13 or higher | Yes
(3.15 or higher) | conan-center | ~/.conan/data or $CONAN_USER_HOME/.conan/data | [The OSI-approved BSD 3-clause License](https://gitlab.kitware.com/cmake/cmake/raw/master/Copyright.txt) -[Ninja](https://ninja-build.org/) | 1.10.0 | Yes | conan-center | ~/.conan/data or $CONAN_USER_HOME/.conan/data | [Apache License v2.0](https://github.com/ninja-build/ninja/blob/master/COPYING) -[GNU* FORTRAN Compiler](https://gcc.gnu.org/wiki/GFortran) | 7.4.0 or higher | Yes | apt | /usr/bin | [GNU General Public License, version 3](https://gcc.gnu.org/onlinedocs/gcc-7.5.0/gfortran/Copying.html) -[Intel(R) oneAPI DPC++ Compiler](https://software.intel.com/en-us/oneapi/dpc-compiler) | latest | No | *N/A* | *Installed by user* | [End User License Agreement for the Intel(R) Software Development Products](https://software.intel.com/en-us/license/eula-for-intel-software-development-products) -[hipSYCL](https://github.com/illuhad/hipSYCL/) | later than [2cfa530](https://github.com/illuhad/hipSYCL/commit/2cfa5303fd88b8f84e539b5bb6ed41e49c6d6118) | No | *N/A* | *Installed by user* | [BSD-2-Clause License ](https://github.com/illuhad/hipSYCL/blob/develop/LICENSE) -[Intel project for LLVM* technology binary for x86 CPU](https://github.com/intel/llvm/releases) | Daily builds (experimental) tested with [20200331](https://github.com/intel/llvm/releases/download/20200331/dpcpp-compiler.tar.gz) | No | *N/A* | *Installed by user* | [Apache License v2](https://github.com/intel/llvm/blob/sycl/sycl/LICENSE.TXT) -[Intel project for LLVM* technology source for NVIDIA GPU](https://github.com/intel/llvm/releases) | Daily source releases: tested with [20200421](https://github.com/intel/llvm/tree/20200421) | No | *N/A* | *Installed by user* | [Apache License v2](https://github.com/intel/llvm/blob/sycl/sycl/LICENSE.TXT) -[Intel(R) oneAPI Math Kernel Library](https://software.intel.com/en-us/oneapi/onemkl) | latest | Yes | apt | /opt/intel/inteloneapi/mkl | [Intel Simplified Software License](https://software.intel.com/en-us/license/intel-simplified-software-license) -[NVIDIA CUDA SDK](https://developer.nvidia.com/cublas) | 10.2 | No | *N/A* | *Installed by user* |[End User License Agreement](https://docs.nvidia.com/cuda/eula/index.html) -[AMD rocBLAS](https://rocblas.readthedocs.io/en/rocm-4.5.2/) | 4.5 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocBLAS/blob/develop/LICENSE.md) -[AMD rocRAND](https://github.com/ROCmSoftwarePlatform/rocRAND) | 5.1.0 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocRAND/blob/develop/LICENSE.txt) -[AMD rocSOLVER](https://github.com/ROCmSoftwarePlatform/rocSOLVER) | 5.0.0 | No | *N/A* | *Installed by user* |[AMD License](https://github.com/ROCmSoftwarePlatform/rocRAND/blob/develop/LICENSE.txt) -[NETLIB LAPACK](https://www.netlib.org/) | 3.7.1 | Yes | conan-community | ~/.conan/data or $CONAN_USER_HOME/.conan/data | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt) -[Sphinx](https://www.sphinx-doc.org/en/master/) | 2.4.4 | Yes | pip | ~/.local/bin (or similar user local directory) | [BSD License](https://github.com/sphinx-doc/sphinx/blob/3.x/LICENSE) - -*conan-center: https://api.bintray.com/conan/conan/conan-center* - -*conan-community: https://api.bintray.com/conan/conan-community/conan* +Product | Supported Version | License +:--- | :--- | :--- +[CMake](https://cmake.org/download/) | 3.13 or higher | [The OSI-approved BSD 3-clause License](https://gitlab.kitware.com/cmake/cmake/raw/master/Copyright.txt) +[Ninja](https://ninja-build.org/) | 1.10.0 | [Apache License v2.0](https://github.com/ninja-build/ninja/blob/master/COPYING) +[GNU* FORTRAN Compiler](https://gcc.gnu.org/wiki/GFortran) | 7.4.0 or higher | [GNU General Public License, version 3](https://gcc.gnu.org/onlinedocs/gcc-7.5.0/gfortran/Copying.html) +[Intel(R) oneAPI DPC++ Compiler](https://software.intel.com/en-us/oneapi/dpc-compiler) | Latest | [End User License Agreement for the Intel(R) Software Development Products](https://software.intel.com/en-us/license/eula-for-intel-software-development-products) +[AdaptiveCpp](https://github.com/AdaptiveCpp/AdaptiveCpp) | Later than [2cfa530](https://github.com/AdaptiveCpp/AdaptiveCpp/commit/2cfa5303fd88b8f84e539b5bb6ed41e49c6d6118) | [BSD-2-Clause License ](https://github.com/AdaptiveCpp/AdaptiveCpp/blob/develop/LICENSE) +[oneAPI DPC++ Compiler binary for x86 CPU](https://github.com/intel/llvm/releases) | Daily builds | [Apache License v2](https://github.com/intel/llvm/blob/sycl/sycl/LICENSE.TXT) +[oneAPI DPC++ Compiler source for NVIDIA and AMD GPUs](https://github.com/intel/llvm) | Daily source releases | [Apache License v2](https://github.com/intel/llvm/blob/sycl/sycl/LICENSE.TXT) +[Intel(R) oneAPI Math Kernel Library](https://software.intel.com/en-us/oneapi/onemkl) | Latest | [Intel Simplified Software License](https://software.intel.com/en-us/license/intel-simplified-software-license) +[NVIDIA CUDA SDK](https://developer.nvidia.com/hpc-sdk) | 12.0 | [End User License Agreement](https://docs.nvidia.com/cuda/eula/index.html) +[AMD rocBLAS](https://github.com/ROCm/rocblas) | 4.5 | [AMD License](https://github.com/ROCm/rocBLAS/blob/develop/LICENSE.md) +[AMD rocRAND](https://github.com/ROCm/rocRAND) | 5.1.0 | [AMD License](https://github.com/ROCm/rocRAND/blob/develop/LICENSE.txt) +[AMD rocSOLVER](https://github.com/ROCm/rocSOLVER) | 5.0.0 | [AMD License](https://github.com/ROCm/rocSOLVER/blob/develop/LICENSE.md) +[AMD rocFFT](https://github.com/ROCm/rocFFT) | rocm-5.4.3 | [AMD License](https://github.com/ROCm/rocFFT/blob/rocm-5.4.3/LICENSE.md) +[NETLIB LAPACK](https://www.netlib.org/) | [5d4180c](https://github.com/Reference-LAPACK/lapack/commit/5d4180cf8288ae6ad9a771d18793d15bd0c5643c) | [BSD like license](http://www.netlib.org/lapack/LICENSE.txt) +[portBLAS](https://github.com/codeplaysoftware/portBLAS) | 0.1 | [Apache License v2.0](https://github.com/codeplaysoftware/portBLAS/blob/main/LICENSE) +[portFFT](https://github.com/codeplaysoftware/portFFT) | 0.1 | [Apache License v2.0](https://github.com/codeplaysoftware/portFFT/blob/main/LICENSE) --- @@ -453,7 +525,8 @@ Python | 3.6 or higher | No | *N/A* | *Pre-installed or Installed by user* | [PS - [About](https://oneapi-src.github.io/oneMKL/introduction.html) - Get Started - [Selecting a Compiler](https://oneapi-src.github.io/oneMKL/selecting_a_compiler.html) - - [Building the Project](https://oneapi-src.github.io/oneMKL/building_the_project.html) + - [Building the Project with DPC++](https://oneapi-src.github.io/oneMKL/building_the_project_with_dpcpp.html) + - [Building the Project with AdaptiveCpp](https://oneapi-src.github.io/oneMKL/building_the_project_with_adaptivecpp.html) - Developer Reference - [oneMKL Defined Datatypes](https://oneapi-src.github.io/oneMKL/onemkl-datatypes.html) - [Dense Linear Algebra](https://oneapi-src.github.io/oneMKL/domains/dense_linear_algebra.html) @@ -461,16 +534,36 @@ Python | 3.6 or higher | No | *N/A* | *Pre-installed or Installed by user* | [PS --- +## Governance + +The oneMKL Interfaces project is governed by the UXL Foundation and you can get involved in this project in multiple ways. It is possible to join the [Math Special Interest Group (SIG)](https://github.com/uxlfoundation/foundation/tree/main/math) meetings where the group discusses and demonstrates work using this project. Members can also join the Open Source and Specification Working Group meetings. + +You can also join the mailing lists for the [UXL Foundation](https://lists.uxlfoundation.org/g/main/subgroups) to be informed of when meetings are happening and receive the latest information and discussions. + +--- + ## Contributing -See [CONTRIBUTING](CONTRIBUTING.md) for more information. +You can contribute to this project and also contribute to [the specification for this project](https://spec.oneapi.io/versions/latest/elements/oneMKL/source/index.html). Please read the [CONTRIBUTING](CONTRIBUTING.md) page for more information. You can also contact oneMKL developers and maintainers via [UXL Foundation Slack](https://slack-invite.uxlfoundation.org/) using [#onemkl](https://uxlfoundation.slack.com/archives/onemkl) channel. + +For GitHub questions, issues, RFCs, or PRs you can contact maintainers via one of the following GitHub teams based on the topic: + +| GitHub team name | Description | +:-----------|:------------| +| @oneapi-src/onemkl-maintain | All oneMKL maintainers | +| @oneapi-src/onemkl-arch-write | oneMKL Architecture maintainers | +| @oneapi-src/onemkl-blas-write | oneMKL BLAS maintainers | +| @oneapi-src/onemkl-dft-write | oneMKL DFT maintainers | +| @oneapi-src/onemkl-lapack-write) | oneMKL LAPACK maintainers | +| @oneapi-src/onemkl-rng-write | oneMKL RNG maintainers | +| @oneapi-src/onemkl-sparse-write | oneMKL Sparse Algebra maintainers | +| @oneapi-src/onemkl-vm-write | oneMKL Vector Math maintainers | --- ## License - Distributed under the Apache license 2.0. See [LICENSE](LICENSE) for more -information. +Distributed under the Apache license 2.0. See [LICENSE](LICENSE) for more information. --- @@ -478,34 +571,30 @@ information. ### oneMKL -1. What is the difference between the following oneMKL items? +**Q: What is the difference between the following oneMKL items?** - The [oneAPI Specification for oneMKL](https://spec.oneapi.com/versions/latest/index.html) - The [oneAPI Math Kernel Library (oneMKL) Interfaces](https://github.com/oneapi-src/oneMKL) Project - The [Intel(R) oneAPI Math Kernel Library (oneMKL)](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html) Product -Answer: - +**A:** - The [oneAPI Specification for oneMKL](https://spec.oneapi.com/versions/latest/index.html) defines the DPC++ interfaces for performance math library functions. The oneMKL specification can evolve faster and more frequently than implementations of the specification. - The [oneAPI Math Kernel Library (oneMKL) Interfaces](https://github.com/oneapi-src/oneMKL) Project is an open source implementation of the specification. The project goal is to demonstrate how the DPC++ interfaces documented in the oneMKL specification can be implemented for any math library and work for any target hardware. While the implementation provided here may not yet be the full implementation of the specification, the goal is to build it out over time. We encourage the community to contribute to this project and help to extend support to multiple hardware targets and other math libraries. - The [Intel(R) oneAPI Math Kernel Library (oneMKL)](https://software.intel.com/content/www/us/en/develop/tools/oneapi/components/onemkl.html) product is the Intel product implementation of the specification (with DPC++ interfaces) as well as similar functionality with C and Fortran interfaces, and is provided as part of Intel® oneAPI Base Toolkit. It is highly optimized for Intel CPU and Intel GPU hardware. -### Conan - -1. I am behind a proxy. How can Conan download dependencies from external network? - - `~/.conan/conan.conf` has a `[proxies]` section where you can add the list of proxies. For details refer to [Conan proxy settings](https://docs.conan.io/en/latest/reference/config_files/conan.conf.html#proxies). - -2. I get an error while installing packages via APT through Conan. - ``` - dpkg: warning: failed to open configuration file '~/.dpkg.cfg' for reading: Permission denied - Setting up intel-oneapi-mkl-devel (2021.1-408.beta07) ... - E: Sub-process /usr/bin/dpkg returned an error code (1) - ``` - - Although your user session has permissions to install packages via `sudo apt`, it does not have permissions to update debian package configuration, which throws an error code 1, causing a failure in `conan install` command. - - The package is most likely installed correctly and can be verified by: - 1. Running the `conan install` command again. - 2. Checking `/opt/intel/inteloneapi` for `mkl` and/or `tbb` directories. +**Q: I'm trying to use oneMKL Interfaces in my project using `FetchContent`**, but I keep running into `ONEMKL::SYCL::SYCL target was not found` problem when I try to build the project. What should I do? + +**A:** +Make sure you set the compiler when you configure your project. +E.g. `cmake -Bbuild . -DCMAKE_CXX_COMPILER=icpx`. + +**Q: I'm trying to use oneMKL Interfaces in my project using `find_package(oneMKL)`.** I set oneMKL/oneTBB and Compiler environment first, then I built and installed oneMKL Interfaces, and finally I tried to build my project using installed oneMKL Interfaces (e.g. like this `cmake -Bbuild -GNinja -DCMAKE_CXX_COMPILER=icpx -DoneMKL_ROOT= .`) and I noticed that cmake includes installed oneMKL Interfaces headers as a system include which ends up as a lower priority than the installed oneMKL package includes which I set before for building oneMKL Interfaces. As a result, I get conflicts between oneMKL and installed oneMKL Interfaces headers. What should I do? + +**A:** +Having installed oneMKL Interfaces headers as `-I` instead on system includes (as `-isystem`) helps to resolve this problem. We use `INTERFACE_INCLUDE_DIRECTORIES` to add paths to installed oneMKL Interfaces headers (check `oneMKLTargets.cmake` in `lib/cmake` to find it). It's a known limitation that `INTERFACE_INCLUDE_DIRECTORIES` puts headers paths as system headers. To avoid that: +- Option 1: Use CMake >=3.25. In this case oneMKL Interfaces will be built with `EXPORT_NO_SYSTEM` property set to `true` and you won't see the issue. +- Option 2: If you use CMake < 3.25, set `PROPERTIES NO_SYSTEM_FROM_IMPORTED true` for your target. E.g: `set_target_properties(test PROPERTIES NO_SYSTEM_FROM_IMPORTED true)`. --- diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 000000000..480361d12 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,65 @@ +# Security Policy +As an open-source project, we understand the importance of and responsibility +for security. This Security policy outlines our guidelines and procedures for +ensuring the highest level of Security and trust for our users who consume +oneMKL Interfaces. + +## Supported Versions +We provide support for the [latest version][1] only. +The security vulnerabilities can be fixed in patch release on top of the latest version.Prior major releases might receive critical security fixes on a best-effort basis; however, we cannot guarantee that security fixes will get back-ported. + +## Report a Vulnerability +We are very grateful to the security researchers and users that report back +security vulnerabilities. We investigate every report thoroughly. +We strongly encourage you to report security vulnerabilities to us privately, +before disclosing them on public forums or opening a public GitHub issue. +Report a vulnerability to us in one of two ways: +* Open a draft [**GitHub Security Advisory**][2] +* Send e-mail to the following address: **security@uxlfoundation.org**. +Along with the report, please include the following info: + * A descriptive title. + * Your name and affiliation (if any). + * A description of the technical details of the vulnerabilities. + * A minimal example of the vulnerability so we can reproduce your findings. + * An explanation of who can exploit this vulnerability, and what they gain + when doing so. + * Whether this vulnerability is public or known to third parties. If it is, + please provide details. + +### When Should I Report a Vulnerability? +* You think you discovered a potential security vulnerability in oneMKL Interfaces. +* You are unsure how the potential vulnerability affects oneMKL Interfaces. +* You think you discovered a vulnerability in another project or 3rd party +component on which oneMKL Interfaces depends. If the issue is not fixed in the 3rd party +component, try to report directly there first. + +### When Should I NOT Report a Vulnerability? +* You got an automated scan hit and are unable to provide details. +* You need help using oneMKL Interfaces for security. +* You need help applying security-related updates. +* Your issue is not security-related. + +## Security Reports Review Process +Our goal is to respond quickly to your inquiry, and to coordinate a fix and +disclosure with you. All confirmed security vulnerabilities will be addressed +according to severity level and impact on oneMKL Interfaces. Normally, security issues +are fixed in the next planned release. + +## Disclosure Policy +We will publish security advisories using the +[**GitHub Security Advisories feature**][3] +to keep our community well-informed, and will credit you for your findings +unless you prefer to stay anonymous. We request that you refrain from +exploiting the vulnerability or making it public before the official disclosure. + +We will disclose the vulnerabilities and/or bugs as soon as possible once +mitigation is implemented and available. + +## Feedback on This Policy +If you have any suggestions on how this Policy could be improved, please submit +an issue or a pull request to this repository. Please **do not** report +potential vulnerabilities or security flaws via a pull request. + +[1]: https://github.com/oneapi-src/oneMKL/releases/latest +[2]: https://github.com/oneapi-src/oneMKL/security/advisories/new +[3]: https://github.com/oneapi-src/oneMKL/security/advisories \ No newline at end of file diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 3d0c12425..df7d2fc4c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -17,7 +17,11 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -install(FILES FindMKL.cmake - FindCompiler.cmake +install(FILES FindCompiler.cmake DESTINATION "lib/cmake/${PROJECT_NAME}" ) +if(ENABLE_MKLGPU_BACKEND OR ENABLE_MKLCPU_BACKEND) + install(FILES mkl/MKLConfig.cmake + DESTINATION "lib/cmake/${PROJECT_NAME}" + ) +endif() diff --git a/cmake/FindCompiler.cmake b/cmake/FindCompiler.cmake index 410756c2f..265719bf0 100644 --- a/cmake/FindCompiler.cmake +++ b/cmake/FindCompiler.cmake @@ -18,21 +18,25 @@ #=============================================================================== include_guard() - include(CheckCXXCompilerFlag) include(FindPackageHandleStandardArgs) - check_cxx_compiler_flag("-fsycl" is_dpcpp) if(is_dpcpp) # Workaround for internal compiler error during linking if -fsycl is used get_filename_component(SYCL_BINARY_DIR ${CMAKE_CXX_COMPILER} DIRECTORY) - find_library(SYCL_LIBRARY NAMES sycl PATHS "${SYCL_BINARY_DIR}/../lib") + find_library(SYCL_LIBRARY NAMES sycl PATHS "${SYCL_BINARY_DIR}/../lib" "${SYCL_BINARY_DIR}/lib" ENV LIBRARY_PATH ENV PATH) + if(NOT SYCL_LIBRARY) + message(FATAL_ERROR "SYCL library is not found in ${SYCL_BINARY_DIR}/../lib, PATH, and LIBRARY_PATH") + endif() add_library(ONEMKL::SYCL::SYCL INTERFACE IMPORTED) if(UNIX) set(UNIX_INTERFACE_COMPILE_OPTIONS -fsycl) set(UNIX_INTERFACE_LINK_OPTIONS -fsycl) + # Check if the Nvidia target is supported. PortFFT uses this for choosing default configuration. + check_cxx_compiler_flag("-fsycl -fsycl-targets=nvptx64-nvidia-cuda" dpcpp_supports_nvptx64) + if(ENABLE_CURAND_BACKEND OR ENABLE_CUSOLVER_BACKEND) list(APPEND UNIX_INTERFACE_COMPILE_OPTIONS -fsycl-targets=nvptx64-nvidia-cuda -fsycl-unnamed-lambda) diff --git a/cmake/FindMKL.cmake b/cmake/FindMKL.cmake deleted file mode 100644 index 2af614b62..000000000 --- a/cmake/FindMKL.cmake +++ /dev/null @@ -1,115 +0,0 @@ -#=============================================================================== -# Copyright 2020-2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# -# -# SPDX-License-Identifier: Apache-2.0 -#=============================================================================== - -include_guard() -set(MKL_SYCL mkl_sycl) -set(MKL_IFACE mkl_intel_ilp64) -set(MKL_SEQ mkl_sequential) -set(MKL_TBB mkl_tbb_thread) -set(MKL_CORE mkl_core) - -set(MKL_C ${MKL_IFACE}) - -if(ENABLE_MKLCPU_THREAD_TBB) - find_package(TBB REQUIRED) - list(APPEND MKL_C ${MKL_TBB}) -else() - list(APPEND MKL_C ${MKL_SEQ}) -endif() - -list(APPEND MKL_C ${MKL_CORE}) - -if(ENABLE_MKLGPU_BACKEND) - set(USE_DPCPP_API ON) -endif() - -if (ENABLE_MKLCPU_BACKEND OR ENABLE_MKLGPU_BACKEND) - if(USE_DPCPP_API) - list(APPEND MKL_LIBRARIES ${MKL_SYCL}) - endif() - list(APPEND MKL_LIBRARIES ${MKL_C}) -endif() - -include(FindPackageHandleStandardArgs) -foreach(lib ${MKL_LIBRARIES}) - find_library(${lib}_file NAMES ${lib} - HINTS $ENV{MKLROOT} ${MKL_ROOT} - PATH_SUFFIXES lib/intel64) - find_package_handle_standard_args(MKL - REQUIRED_VARS ${lib}_file - VERSION_VAR MKL_VERSION) -endforeach() - -get_filename_component(MKL_LIB_DIR ${mkl_core_file} DIRECTORY) - -find_path(MKL_INCLUDE mkl.h - HINTS $ENV{MKLROOT} ${MKL_ROOT} - PATH_SUFFIXES include) - -file(READ "${MKL_INCLUDE}/mkl_version.h" mkl_version_h) -string(REGEX MATCH "INTEL_MKL_VERSION ([0-9]*)" _ ${mkl_version_h}) -set(MKL_VERSION ${CMAKE_MATCH_1}) - -if(${CMAKE_SIZEOF_VOID_P} EQUAL 8 OR USE_DPCPP_API) - set(MKL_COPT "-DMKL_ILP64") -else() - set(MKL_COPT "") -endif() -list(APPEND MKL_COPT "-DINTEL_MKL_VERSION=${MKL_VERSION}") - -if(UNIX) - list(APPEND MKL_LINK_PREFIX "-Wl,-rpath,${MKL_LIB_DIR}") - list(APPEND MKL_LINK_PREFIX "-L${MKL_LIB_DIR}") - set(LIB_PREFIX "-l") - set(LIB_SUFFIX "") -else() - set(LIB_PREFIX "${MKL_LIB_DIR}/") - set(LIB_SUFFIX ".lib") -endif() - -if (ENABLE_MKLCPU_BACKEND OR ENABLE_MKLGPU_BACKEND) - set(MKL_LINK_C ${MKL_LINK_PREFIX}) - foreach(lib ${MKL_C}) - list(APPEND MKL_LINK_C ${LIB_PREFIX}${lib}${LIB_SUFFIX}) - endforeach() - if(ENABLE_MKLCPU_THREAD_TBB) - list(APPEND MKL_LINK_C ${TBB_LINK}) - endif() - if(USE_DPCPP_API) - find_package(OpenCL QUIET) - # Try to find OpenCL library in the environment - if(${OpenCL_LIBRARY} STREQUAL "OpenCL_LIBRARY-NOTFOUND") - find_library(OPENCL_LIBNAME NAMES libOpenCL.so OpenCL.lib OpenCL HINTS ENV LIBRARY_PATH ENV LD_LIBRARY_PATH ENV LIB ENV PATH) - else() - set(OPENCL_LIBNAME ${OpenCL_LIBRARY}) - endif() - find_package_handle_standard_args(MKL REQUIRED_VARS OPENCL_LIBNAME) - set(MKL_LINK_SYCL ${MKL_LINK_PREFIX} ${LIB_PREFIX}${MKL_SYCL}${LIB_SUFFIX} ${MKL_LINK_C} ${OPENCL_LIBNAME} ) - endif() -endif() - -if (USE_DPCPP_API) - find_package_handle_standard_args(MKL - REQUIRED_VARS MKL_INCLUDE MKL_COPT MKL_LINK_SYCL - VERSION_VAR MKL_VERSION) -else(ENABLE_MKLCPU_BACKEND) - find_package_handle_standard_args(MKL - REQUIRED_VARS MKL_INCLUDE MKL_COPT MKL_LINK_C - VERSION_VAR MKL_VERSION) -endif() diff --git a/cmake/FindTBB.cmake b/cmake/FindTBB.cmake deleted file mode 100644 index 61934c187..000000000 --- a/cmake/FindTBB.cmake +++ /dev/null @@ -1,53 +0,0 @@ -#=============================================================================== -# Copyright 2020-2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# -# -# SPDX-License-Identifier: Apache-2.0 -#=============================================================================== - -include_guard() - -#Workaround for soname problem -if(UNIX) - set(TBB_LIBNAME libtbb.so) -else() - set(TBB_LIBNAME tbb.lib) -endif() - -find_path(TBB_LIB_DIR ${TBB_LIBNAME} - HINTS $ENV{TBBROOT} $ENV{MKLROOT} ${MKL_ROOT} ${TBB_ROOT} - PATH_SUFFIXES "lib" "lib/intel64/gcc4.4" "lib/intel64/gcc4.8" - "../tbb/lib/intel64/gcc4.4" "../tbb/lib/intel64/gcc4.8" - "../../tbb/latest/lib/intel64/gcc4.8" - "../tbb/lib/intel64/vc14" "lib/intel64/vc14" -) - -find_library(TBB_LIBRARIES NAMES tbb - HINTS $ENV{TBBROOT} $ENV{MKLROOT} ${MKL_ROOT} ${TBB_ROOT} - PATH_SUFFIXES "lib" "lib/intel64/gcc4.4" "lib/intel64/gcc4.8" - "../tbb/lib/intel64/gcc4.4" "../tbb/lib/intel64/gcc4.8" - "../../tbb/latest/lib/intel64/gcc4.8" - "../tbb/lib/intel64/vc14" "lib/intel64/vc14" - ) - -#Workaround for ref problem -if(UNIX) - set(TBB_LINK "-Wl,-rpath,${TBB_LIB_DIR} -L${TBB_LIB_DIR} -ltbb") -else() - set(TBB_LINK ${TBB_LIBRARIES}) -endif() -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(TBB REQUIRED_VARS TBB_LIBRARIES TBB_LINK) - diff --git a/cmake/FindcuBLAS.cmake b/cmake/FindcuBLAS.cmake index b7ea6b911..c26a62f6b 100644 --- a/cmake/FindcuBLAS.cmake +++ b/cmake/FindcuBLAS.cmake @@ -25,6 +25,7 @@ find_path(OPENCL_INCLUDE_DIR CL/cl.h OpenCL/cl.h HINTS ${OPENCL_INCLUDE_DIR} ${SYCL_BINARY_DIR}/../include/sycl/ +${SYCL_BINARY_DIR}/../../include/sycl/ ) # this is work around to avoid duplication half creation in both cuda and SYCL add_compile_definitions(CUDA_NO_HALF) diff --git a/cmake/FindrocBLAS.cmake b/cmake/FindrocBLAS.cmake deleted file mode 100644 index bbcb56664..000000000 --- a/cmake/FindrocBLAS.cmake +++ /dev/null @@ -1,60 +0,0 @@ -#========================================================================== -# Copyright 2020-2022 Intel Corporation -# Copyright (C) Codeplay Software Limited -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# For your convenience, a copy of the License has been included in this -# repository. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -#========================================================================= - -if(NOT DEFINED HIP_PATH) - if(NOT DEFINED ENV{HIP_PATH}) - set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") - else() - set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") - endif() -endif() - -set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) -list(APPEND CMAKE_PREFIX_PATH - "${HIP_PATH}/lib/cmake" - "${HIP_PATH}/../lib/cmake" -) - -find_package(HIP QUIET) -find_package(rocblas REQUIRED) - -# this is work around to avoid duplication half creation in both HIP and SYCL -add_compile_definitions(HIP_NO_HALF) - -find_package(Threads REQUIRED) - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(rocBLAS - REQUIRED_VARS - HIP_INCLUDE_DIRS - HIP_LIBRARIES - ROCBLAS_INCLUDE_DIR - ROCBLAS_LIBRARIES -) -# OPENCL_INCLUDE_DIR -if(NOT TARGET ONEMKL::rocBLAS::rocBLAS) - add_library(ONEMKL::rocBLAS::rocBLAS SHARED IMPORTED) - set_target_properties(ONEMKL::rocBLAS::rocBLAS PROPERTIES - IMPORTED_LOCATION "${HIP_PATH}/../rocblas/lib/librocblas.so" - INTERFACE_INCLUDE_DIRECTORIES "${ROCBLAS_INCLUDE_DIR};${HIP_INCLUDE_DIRS};" - INTERFACE_LINK_LIBRARIES "Threads::Threads;${ROCBLAS_LIBRARIES};" - ) - -endif() diff --git a/cmake/FindrocRAND.cmake b/cmake/FindrocRAND.cmake deleted file mode 100644 index 9d8cb7c35..000000000 --- a/cmake/FindrocRAND.cmake +++ /dev/null @@ -1,55 +0,0 @@ -#=============================================================================== -# Copyright 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# -# -# SPDX-License-Identifier: Apache-2.0 -#=============================================================================== - -if(NOT DEFINED HIP_PATH) - if(NOT DEFINED ENV{HIP_PATH}) - set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") - else() - set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") - endif() -endif() - -set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) -list(APPEND CMAKE_PREFIX_PATH - "${HIP_PATH}/lib/cmake" - "${HIP_PATH}/../lib/cmake" - "${HIP_PATH}/../lib/cmake/rocrand/rocrand") - -find_package(rocrand REQUIRED) -find_package(hip QUIET) - -# this is work around to avoid duplication half creation in both hip and SYCL -add_compile_definitions(HIP_NO_HALF) - -find_package(Threads REQUIRED) - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(rocRAND - REQUIRED_VARS - HIP_INCLUDE_DIRS - rocrand_INCLUDE_DIR - rocrand_LIBRARIES) - -if(NOT TARGET ONEMKL::rocRAND::rocRAND) - add_library(ONEMKL::rocRAND::rocRAND SHARED IMPORTED) - set_target_properties(ONEMKL::rocRAND::rocRAND PROPERTIES - IMPORTED_LOCATION "${HIP_PATH}/../rocrand/lib/librocrand.so" - INTERFACE_INCLUDE_DIRECTORIES "${rocrand_INCLUDE_DIR};${HIP_INCLUDE_DIRS};" - INTERFACE_LINK_LIBRARIES "Threads::Threads;hip::host;${rocrand_LIBRARIES};") -endif() diff --git a/cmake/FindrocSOLVER.cmake b/cmake/FindrocSOLVER.cmake deleted file mode 100644 index c1145443e..000000000 --- a/cmake/FindrocSOLVER.cmake +++ /dev/null @@ -1,56 +0,0 @@ -#=============================================================================== -# Copyright 2022 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# -# -# SPDX-License-Identifier: Apache-2.0 -#=============================================================================== - -if(NOT DEFINED HIP_PATH) - if(NOT DEFINED ENV{HIP_PATH}) - set(HIP_PATH "/opt/rocm/hip" CACHE PATH "Path to which HIP has been installed") - else() - set(HIP_PATH $ENV{HIP_PATH} CACHE PATH "Path to which HIP has been installed") - endif() -endif() - -set(CMAKE_MODULE_PATH "${HIP_PATH}/cmake" ${CMAKE_MODULE_PATH}) -list(APPEND CMAKE_PREFIX_PATH - "${HIP_PATH}/lib/cmake" - "${HIP_PATH}/../lib/cmake" - "${HIP_PATH}/../lib/cmake/rocsolver") - -find_package(HIP QUIET) -find_package(rocsolver REQUIRED) - -# this is work around to avoid duplication half creation in both hip and SYCL -add_compile_definitions(HIP_NO_HALF) - -find_package(Threads REQUIRED) - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(rocSOLVER - REQUIRED_VARS - HIP_INCLUDE_DIRS - rocsolver_INCLUDE_DIR - rocsolver_LIBRARIES) - -if(NOT TARGET ONEMKL::rocSOLVER::rocSOLVER) - add_library(ONEMKL::rocSOLVER::rocSOLVER SHARED IMPORTED) - set_target_properties(ONEMKL::rocSOLVER::rocSOLVER PROPERTIES - IMPORTED_LOCATION "${HIP_PATH}/../rocsolver/lib/librocsolver.so" - INTERFACE_INCLUDE_DIRECTORIES "${rocsolver_INCLUDE_DIR};${HIP_INCLUDE_DIRS};" - INTERFACE_LINK_LIBRARIES "Threads::Threads;hip::host;${rocsolver_LIBRARIES};") -endif() - diff --git a/cmake/WarningsUtils.cmake b/cmake/WarningsUtils.cmake new file mode 100644 index 000000000..3b5f76afb --- /dev/null +++ b/cmake/WarningsUtils.cmake @@ -0,0 +1,48 @@ +#=============================================================================== +# Copyright Codeplay Software Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +include_guard(GLOBAL) + +add_library(onemkl_warnings INTERFACE) + +set(ONEMKL_WARNINGS "") + +include(CheckCXXCompilerFlag) +macro(add_warning flag) + check_cxx_compiler_flag(${flag} IS_SUPPORTED) + if(${IS_SUPPORTED}) + list(APPEND ONEMKL_WARNINGS ${flag}) + else() + message(WARNING "Compiler does not support ${flag}") + endif() +endmacro() + +add_warning("-Wall") +add_warning("-Wextra") +add_warning("-Wshadow") +add_warning("-Wconversion") +add_warning("-Wpedantic") + +message(VERBOSE "Domains with warnings enabled use: ${ONEMKL_WARNINGS}") + +# The onemkl_warnings target can be linked to any other target to enable warnings. +target_compile_options(onemkl_warnings INTERFACE ${ONEMKL_WARNINGS}) + +# Add the library to install package +install(TARGETS onemkl_warnings EXPORT oneMKLTargets) diff --git a/cmake/mkl/MKLConfig.cmake b/cmake/mkl/MKLConfig.cmake new file mode 100644 index 000000000..7614288b3 --- /dev/null +++ b/cmake/mkl/MKLConfig.cmake @@ -0,0 +1,1158 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +#=================================================================== +# CMake Config file for Intel(R) oneAPI Math Kernel Library (oneMKL) +#=============================================================================== + +#=============================================================================== +# Input parameters +#================= +#------------- +# Main options +#------------- +# MKL_ROOT: oneMKL root directory (May be required for non-standard install locations. Optional otherwise.) +# Default: use location from MKLROOT environment variable or /../../../ if MKLROOT is not defined +# MKL_ARCH +# Values: ia32 intel64 +# Default: intel64 +# MKL_LINK +# Values: static, dynamic, sdl +# Default: dynamic +# Exceptions:- DPC++ doesn't support sdl +# MKL_THREADING +# Values: sequential, +# intel_thread (Intel OpenMP), +# gnu_thread (GNU OpenMP), +# pgi_thread (PGI OpenMP) [PGI support is deprecated], +# tbb_thread +# Default: intel_thread +# Exceptions:- DPC++ defaults to oneTBB, PGI compiler on Windows defaults to pgi_thread +# MKL_INTERFACE (for MKL_ARCH=intel64 only) +# Values: lp64, ilp64 +# GNU or INTEL interface will be selected based on Compiler. +# Default: ilp64 +# MKL_MPI +# Values: intelmpi, mpich, openmpi, msmpi, mshpc +# Default: intelmpi +#----------------------------------- +# Special options (OFF by default) +#----------------------------------- +# ENABLE_BLAS95: Enables BLAS Fortran95 API +# ENABLE_LAPACK95: Enables LAPACK Fortran95 API +# ENABLE_BLACS: Enables cluster BLAS library +# ENABLE_CDFT: Enables cluster DFT library +# ENABLE_CPARDISO: Enables cluster PARDISO functionality +# ENABLE_SCALAPACK: Enables cluster LAPACK library +# ENABLE_OMP_OFFLOAD: Enables OpenMP Offload functionality +# +#================== +# Output parameters +#================== +# MKL_ROOT +# oneMKL root directory. +# MKL_INCLUDE +# Use of target_include_directories() is recommended. +# INTERFACE_INCLUDE_DIRECTORIES property is set on mkl_core and mkl_rt libraries. +# Alternatively, this variable can be used directly (not recommended as per Modern CMake) +# MKL_ENV +# Provides all environment variables based on input parameters. +# Currently useful for mkl_rt linking and BLACS on Windows. +# Must be set as an ENVIRONMENT property. +# Example: +# add_test(NAME mytest COMMAND myexe) +# if(MKL_ENV) +# set_tests_properties(mytest PROPERTIES ENVIRONMENT "${MKL_ENV}") +# endif() +# +# MKL:: +# IMPORTED targets to link oneMKL libraries individually or when using a custom link-line. +# mkl_core and mkl_rt have INTERFACE_* properties set to them. +# Please refer to Intel(R) oneMKL Link Line Advisor for help with linking. +# +# Below INTERFACE targets provide full link-lines for direct use. +# Example: +# target_link_options( PUBLIC $) +# +# MKL::MKL +# Link line for C and Fortran API +# MKL::MKL_SYCL +# Link line for DPC++ API +# +# Note: For Device API, library linking is not required. +# Compile options can be added from the INTERFACE_COMPILE_OPTIONS property on MKL::MKL_SYCL +# Include directories can be added from the INTERFACE_INCLUDE_DIRECTORIES property on MKL::MKL_SYCL +# +# Note: Output parameters' and targets' availability can change +# based on Input parameters and application project languages. +#=============================================================================== + +include_guard() + +if(NOT MKL_LIBRARIES) + +function(mkl_message MSG_MODE MSG_TEXT) + if(MSG_MODE STREQUAL "FATAL_ERROR") + message(${MSG_MODE} ${MSG_TEXT}) + else() + if(NOT MKL_FIND_QUIETLY) + message(${MSG_MODE} ${MSG_TEXT}) + endif() + endif() +endfunction() + +if(CMAKE_VERSION VERSION_LESS "3.13") + mkl_message(FATAL_ERROR "The minimum supported CMake version is 3.13. You are running version ${CMAKE_VERSION}.") +endif() + +# Set CMake policies for well-defined behavior across CMake versions +cmake_policy(SET CMP0011 NEW) +cmake_policy(SET CMP0057 NEW) + +# Project Languages +get_property(languages GLOBAL PROPERTY ENABLED_LANGUAGES) +list(APPEND MKL_LANGS C CXX Fortran) +foreach(lang ${languages}) + if(${lang} IN_LIST MKL_LANGS) + list(APPEND CURR_LANGS ${lang}) + endif() +endforeach() +list(REMOVE_DUPLICATES CURR_LANGS) + +option(ENABLE_BLAS95 "Enables BLAS Fortran95 API" OFF) +option(ENABLE_LAPACK95 "Enables LAPACK Fortran95 API" OFF) +option(ENABLE_BLACS "Enables cluster BLAS library" OFF) +option(ENABLE_CDFT "Enables cluster DFT library" OFF) +option(ENABLE_CPARDISO "Enables cluster PARDISO functionality" OFF) +option(ENABLE_SCALAPACK "Enables cluster LAPACK library" OFF) +option(ENABLE_OMP_OFFLOAD "Enables OpenMP Offload functionality" OFF) + +# Use MPI if any of these are enabled +if(ENABLE_BLACS OR ENABLE_CDFT OR ENABLE_SCALAPACK OR ENABLE_CPARDISO) + set(USE_MPI ON) +endif() + +# Check Parameters +function(define_param TARGET_PARAM DEFAULT_PARAM SUPPORTED_LIST) + if(NOT DEFINED ${TARGET_PARAM} AND NOT DEFINED ${DEFAULT_PARAM}) + mkl_message(STATUS "${TARGET_PARAM}: Undefined") + elseif(NOT DEFINED ${TARGET_PARAM} AND DEFINED ${DEFAULT_PARAM}) + set(${TARGET_PARAM} "${${DEFAULT_PARAM}}" CACHE STRING "Choose ${TARGET_PARAM} options are: ${${SUPPORTED_LIST}}") + foreach(opt ${${DEFAULT_PARAM}}) + set(STR_LIST "${STR_LIST} ${opt}") + endforeach() + mkl_message(STATUS "${TARGET_PARAM}: None, set to `${STR_LIST}` by default") + elseif(${SUPPORTED_LIST}) + set(ITEM_FOUND 1) + foreach(opt ${${TARGET_PARAM}}) + if(NOT ${opt} IN_LIST ${SUPPORTED_LIST}) + set(ITEM_FOUND 0) + endif() + endforeach() + if(ITEM_FOUND EQUAL 0) + foreach(opt ${${SUPPORTED_LIST}}) + set(STR_LIST "${STR_LIST} ${opt}") + endforeach() + if(${ARGC} EQUAL 3) + mkl_message(FATAL_ERROR "Invalid ${TARGET_PARAM} `${${TARGET_PARAM}}`, options are: ${STR_LIST}") + elseif(${ARGC} EQUAL 4) + mkl_message(${ARGV3} "Invalid ${TARGET_PARAM} `${${TARGET_PARAM}}`, options are: ${STR_LIST}") + set(${TARGET_PARAM} "" PARENT_SCOPE) + endif() + else() + mkl_message(STATUS "${TARGET_PARAM}: ${${TARGET_PARAM}}") + endif() + else() + mkl_message(STATUS "${TARGET_PARAM}: ${${TARGET_PARAM}}") + endif() +endfunction() + +macro(check_required_vars) + foreach(var IN ITEMS ${ARGV}) + if(NOT ${var}) + set(${CMAKE_FIND_PACKAGE_NAME}_NOT_FOUND_MESSAGE "The required variable ${var} has an invalid value \"${${var}}\".") + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND FALSE) + return() + endif() + endforeach() +endmacro() + +#================ +# Compiler checks +#================ + +if(CMAKE_C_COMPILER) + get_filename_component(C_COMPILER_NAME ${CMAKE_C_COMPILER} NAME) +endif() +if(CMAKE_CXX_COMPILER) + get_filename_component(CXX_COMPILER_NAME ${CMAKE_CXX_COMPILER} NAME) +endif() +if(CMAKE_Fortran_COMPILER) + get_filename_component(Fortran_COMPILER_NAME ${CMAKE_Fortran_COMPILER} NAME) +endif() + +# Determine Compiler Family +if(CXX_COMPILER_NAME STREQUAL "dpcpp" OR CXX_COMPILER_NAME STREQUAL "dpcpp.exe" + OR CXX_COMPILER_NAME STREQUAL "icpx" OR CXX_COMPILER_NAME STREQUAL "icx.exe") + set(SYCL_COMPILER ON) +endif() +if(C_COMPILER_NAME MATCHES "^clang" OR CXX_COMPILER_NAME MATCHES "^clang") + set(CLANG_COMPILER ON) +endif() +if(CMAKE_C_COMPILER_ID STREQUAL "PGI" OR CMAKE_CXX_COMPILER_ID STREQUAL "PGI" OR CMAKE_Fortran_COMPILER_ID STREQUAL "PGI" + OR CMAKE_C_COMPILER_ID STREQUAL "NVHPC" OR CMAKE_CXX_COMPILER_ID STREQUAL "NVHPC" + OR CMAKE_Fortran_COMPILER_ID STREQUAL "NVHPC") # PGI 22.9 + mkl_message(WARNING "PGI support is deprecated and will be removed in the oneMKL 2025.0 release.") + set(PGI_COMPILER ON) +elseif(CMAKE_C_COMPILER_ID STREQUAL "Intel" OR CMAKE_CXX_COMPILER_ID STREQUAL "Intel" OR CMAKE_Fortran_COMPILER_ID STREQUAL "Intel" + OR CMAKE_C_COMPILER_ID STREQUAL "IntelLLVM" OR CMAKE_CXX_COMPILER_ID STREQUAL "IntelLLVM" OR CMAKE_Fortran_COMPILER_ID STREQUAL "IntelLLVM") + set(INTEL_COMPILER ON) +else() + if(CMAKE_C_COMPILER_ID STREQUAL "GNU" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU") + set(GNU_C_COMPILER ON) + endif() + if(CMAKE_Fortran_COMPILER_ID STREQUAL "GNU") + set(GNU_Fortran_COMPILER ON) + endif() +endif() +# CMake identifies IntelLLVM compilers only after 3.20 +if(NOT INTEL_COMPILER) + if(C_COMPILER_NAME STREQUAL "icx" OR C_COMPILER_NAME STREQUAL "icx.exe" + OR CXX_COMPILER_NAME STREQUAL "icpx" OR CXX_COMPILER_NAME STREQUAL "icx.exe" + OR Fortran_COMPILER_NAME STREQUAL "ifx" OR Fortran_COMPILER_NAME STREQUAL "ifx.exe") + set(INTEL_COMPILER ON) + endif() +endif() +# CMake supports IntelLLVM compilers only after 3.25.2 +if(CMAKE_VERSION VERSION_LESS "3.25.2") + if(C_COMPILER_NAME STREQUAL "icx" OR C_COMPILER_NAME STREQUAL "icx.exe" OR CXX_COMPILER_NAME STREQUAL "icx.exe") + list(APPEND INTEL_LLVM_COMPILERS_IN_USE "icx") + endif() + if(CXX_COMPILER_NAME STREQUAL "icpx") + list(APPEND INTEL_LLVM_COMPILERS_IN_USE "icpx") + endif() + if(Fortran_COMPILER_NAME STREQUAL "ifx" OR Fortran_COMPILER_NAME STREQUAL "ifx.exe") + list(APPEND INTEL_LLVM_COMPILERS_IN_USE "ifx") + endif() + if(INTEL_LLVM_COMPILERS_IN_USE) + list(JOIN INTEL_LLVM_COMPILERS_IN_USE ", " INTEL_LLVM_COMPILERS_IN_USE_COMMA) + mkl_message(STATUS "Upgrade to CMake version 3.25.2 or later for native support of Intel compiler(s) ${INTEL_LLVM_COMPILERS_IN_USE_COMMA}. You are running version ${CMAKE_VERSION}.") + endif() +endif() + +if(USE_MPI AND (C_COMPILER_NAME MATCHES "^mpi" OR Fortran_COMPILER_NAME MATCHES "^mpi")) + set(USE_MPI_SCRIPT ON) +endif() + +#================ + +#================ +# System-specific +#================ + +# Extensions +if(UNIX) + set(LIB_PREFIX "lib") + set(LIB_EXT ".a") + set(DLL_EXT ".so") + if(APPLE) + set(DLL_EXT ".dylib") + endif() + set(LINK_PREFIX "-l") + set(LINK_SUFFIX "") +else() + set(LIB_PREFIX "") + set(LIB_EXT ".lib") + set(DLL_EXT "_dll.lib") + set(LINK_PREFIX "") + set(LINK_SUFFIX ".lib") +endif() + +#================ + +#============= +# Setup oneMKL +#============= + +# Set MKL_ROOT directory +if(NOT DEFINED MKL_ROOT) + if(DEFINED ENV{MKLROOT}) + set(MKL_ROOT $ENV{MKLROOT}) + # Verify that the version in MKL_ROOT is the same as MKL_VERSION + find_file(MKL_VERSION_H mkl_version.h + HINTS ${MKL_ROOT} + PATH_SUFFIXES include + NO_DEFAULT_PATH) + check_required_vars(MKL_VERSION_H) + file(READ ${MKL_VERSION_H} MKL_VERSION_H_CONTENT) + string(REGEX MATCH "__INTEL_MKL__ +([0-9]+)" MKL_VERSION_INFO ${MKL_VERSION_H_CONTENT}) + set(MKL_ROOT_MAJOR_VERSION ${CMAKE_MATCH_1}) + string(REGEX MATCH "__INTEL_MKL_UPDATE__ +([0-9]+)" MKL_VERSION_INFO ${MKL_VERSION_H_CONTENT}) + set(MKL_ROOT_UPDATE_VERSION ${CMAKE_MATCH_1}) + set(MKL_ROOT_VERSION ${MKL_ROOT_MAJOR_VERSION}.${MKL_ROOT_UPDATE_VERSION}) + if(NOT MKL_ROOT_VERSION VERSION_EQUAL ${CMAKE_FIND_PACKAGE_NAME}_VERSION) + mkl_message(FATAL_ERROR "oneMKL ${MKL_ROOT_VERSION} specified by the environment variable MKLROOT \ + mismatches the found version ${${CMAKE_FIND_PACKAGE_NAME}_VERSION} \ + indicated by ${CMAKE_CURRENT_LIST_DIR}/MKLConfigVersion.cmake") + endif() + else() + get_filename_component(MKL_CMAKE_PATH "${CMAKE_CURRENT_LIST_DIR}" REALPATH) + get_filename_component(MKL_ROOT "${MKL_CMAKE_PATH}/../../../" ABSOLUTE) + endif() +endif() +string(REPLACE "\\" "/" MKL_ROOT ${MKL_ROOT}) +check_required_vars(MKL_ROOT) +mkl_message(STATUS "${CMAKE_FIND_PACKAGE_NAME}_VERSION: ${${CMAKE_FIND_PACKAGE_NAME}_VERSION}") +mkl_message(STATUS "MKL_ROOT: ${MKL_ROOT}") + +# Set target system architecture +if(SYCL_COMPILER) + set(DEFAULT_MKL_SYCL_ARCH intel64) + set(MKL_SYCL_ARCH_LIST intel64) + if(NOT DEFINED MKL_SYCL_ARCH) + set(MKL_SYCL_ARCH ${MKL_ARCH}) + endif() + define_param(MKL_SYCL_ARCH DEFAULT_MKL_SYCL_ARCH MKL_SYCL_ARCH_LIST STATUS) + if(NOT MKL_SYCL_ARCH) + set(SYCL_COMPILER OFF) + mkl_message(STATUS "MKL::MKL_SYCL target will not be available.") + endif() +endif() +set(DEFAULT_MKL_ARCH intel64) +if(PGI_COMPILER OR ENABLE_OMP_OFFLOAD OR USE_MPI) + set(MKL_ARCH_LIST intel64) +else() + set(MKL_ARCH_LIST ia32 intel64) +endif() +define_param(MKL_ARCH DEFAULT_MKL_ARCH MKL_ARCH_LIST) +check_required_vars(MKL_ARCH) +if(MKL_ARCH STREQUAL "ia32") + set(MKL_ARCH_DIR "32") +else() + set(MKL_ARCH_DIR "") +endif() + +# Define MKL_LINK +if(SYCL_COMPILER) + set(DEFAULT_MKL_SYCL_LINK dynamic) + set(MKL_SYCL_LINK_LIST static dynamic) + if(NOT DEFINED MKL_SYCL_LINK) + set(MKL_SYCL_LINK ${MKL_LINK}) + endif() + define_param(MKL_SYCL_LINK DEFAULT_MKL_SYCL_LINK MKL_SYCL_LINK_LIST STATUS) + if(NOT MKL_SYCL_LINK) + set(SYCL_COMPILER OFF) + mkl_message(STATUS "MKL::MKL_SYCL target will not be available.") + endif() +endif() +set(DEFAULT_MKL_LINK dynamic) +if(USE_MPI) + set(MKL_LINK_LIST static dynamic) +else() + set(MKL_LINK_LIST static dynamic sdl) +endif() +define_param(MKL_LINK DEFAULT_MKL_LINK MKL_LINK_LIST) +check_required_vars(MKL_LINK) + +# Define MKL_INTERFACE +if(SYCL_COMPILER) + if(MKL_INTERFACE AND NOT DEFINED MKL_SYCL_INTERFACE_FULL) + set(MKL_SYCL_INTERFACE_FULL intel_${MKL_INTERFACE}) + endif() + set(DEFAULT_MKL_SYCL_INTERFACE intel_ilp64) + set(MKL_SYCL_INTERFACE_LIST intel_ilp64) + define_param(MKL_SYCL_INTERFACE_FULL DEFAULT_MKL_SYCL_INTERFACE MKL_SYCL_INTERFACE_LIST STATUS) + if(NOT MKL_SYCL_INTERFACE_FULL) + set(SYCL_COMPILER OFF) + mkl_message(STATUS "MKL::MKL_SYCL target will not be available.") + endif() +endif() +if(MKL_ARCH STREQUAL "intel64") + set(IFACE_TYPE intel) + if(GNU_Fortran_COMPILER) + set(IFACE_TYPE gf) + endif() + if(MKL_INTERFACE) + set(MKL_INTERFACE_FULL ${IFACE_TYPE}_${MKL_INTERFACE}) + endif() + set(DEFAULT_MKL_INTERFACE ${IFACE_TYPE}_ilp64) + set(MKL_INTERFACE_LIST ${IFACE_TYPE}_ilp64 ${IFACE_TYPE}_lp64) + define_param(MKL_INTERFACE_FULL DEFAULT_MKL_INTERFACE MKL_INTERFACE_LIST) +else() + if(WIN32) + set(MKL_INTERFACE_FULL intel_c) + elseif(NOT APPLE) + if(GNU_Fortran_COMPILER) + set(MKL_INTERFACE_FULL gf) + else() + set(MKL_INTERFACE_FULL intel) + endif() + else() + mkl_message(FATAL_ERROR "OSX does not support MKL_ARCH ia32.") + endif() +endif() +if(MKL_INTERFACE_FULL MATCHES "ilp64") + set(MKL_INTERFACE "ilp64") +else() + set(MKL_INTERFACE "lp64") +endif() +check_required_vars(MKL_INTERFACE_FULL) + +# Define oneMKL headers +find_path(MKL_INCLUDE mkl.h + HINTS ${MKL_ROOT} + PATH_SUFFIXES include + NO_DEFAULT_PATH) +check_required_vars(MKL_INCLUDE) + +# Add pre-built F95 Interface Modules +if(INTEL_COMPILER AND (ENABLE_BLAS95 OR ENABLE_LAPACK95)) + if(MKL_ARCH STREQUAL "intel64") + list(APPEND MKL_INCLUDE "${MKL_ROOT}/include/mkl/${MKL_ARCH}/${MKL_INTERFACE}") + else() + list(APPEND MKL_INCLUDE "${MKL_ROOT}/include/mkl/${MKL_ARCH}") + endif() +endif() + +# Define MKL_THREADING +# All APIs support sequential threading +# SYCL API supports oneTBB and OpenMP threadings, but OpenMP threading might have composability problem on CPU device with other SYCL kernels +if(SYCL_COMPILER) + set(MKL_SYCL_THREADING_LIST "sequential" "intel_thread" "tbb_thread") + set(DEFAULT_MKL_SYCL_THREADING tbb_thread) + if(NOT DEFINED MKL_SYCL_THREADING) + set(MKL_SYCL_THREADING ${MKL_THREADING}) + endif() + define_param(MKL_SYCL_THREADING DEFAULT_MKL_SYCL_THREADING MKL_SYCL_THREADING_LIST STATUS) + if(NOT MKL_SYCL_THREADING) + set(SYCL_COMPILER OFF) + mkl_message(STATUS "MKL::MKL_SYCL target will not be available.") + endif() + if(MKL_SYCL_THREADING STREQUAL "intel_thread") + mkl_message(STATUS "Using MKL::MKL_SYCL* targets with intel_thread may have potential composability problems on CPU device with other SYCL kernels.") + add_custom_target(MKL_SYCL_MESSAGE + COMMAND ${CMAKE_COMMAND} -E cmake_echo_color --red + "Warning: Using MKL::MKL_SYCL* targets with intel_thread may have potential composability problems on CPU device with other SYCL kernels.") + endif() +endif() +# C, Fortran API +set(MKL_THREADING_LIST "sequential" "intel_thread" "tbb_thread") +set(DEFAULT_MKL_THREADING intel_thread) +if(PGI_COMPILER) + # PGI compiler supports PGI OpenMP threading, additionally + list(APPEND MKL_THREADING_LIST pgi_thread) + # PGI compiler does not support oneTBB threading + list(REMOVE_ITEM MKL_THREADING_LIST tbb_thread) + if(WIN32) + # PGI 19.10 and 20.1 on Windows, do not support Intel OpenMP threading + list(REMOVE_ITEM MKL_THREADING_LIST intel_thread) + set(DEFAULT_MKL_THREADING pgi_thread) + endif() +elseif(GNU_C_COMPILER OR GNU_Fortran_COMPILER OR CLANG_COMPILER) + list(APPEND MKL_THREADING_LIST gnu_thread) +else() + # Intel and Microsoft compilers + # Nothing to do, only for completeness +endif() +define_param(MKL_THREADING DEFAULT_MKL_THREADING MKL_THREADING_LIST) +check_required_vars(MKL_THREADING) + +# Define MKL_MPI +if(USE_MPI) + set(DEFAULT_MKL_MPI intelmpi) + if(UNIX) + if(APPLE) + # Override defaults for OSX + set(DEFAULT_MKL_MPI mpich) + set(MKL_MPI_LIST mpich) + else() + set(MKL_MPI_LIST intelmpi openmpi mpich mpich2) + endif() + else() + # Windows + set(MKL_MPI_LIST intelmpi mshpc msmpi) + endif() + define_param(MKL_MPI DEFAULT_MKL_MPI MKL_MPI_LIST) + # MSMPI is now called MSHPC. MSMPI option exists for backward compatibility. + if(MKL_MPI STREQUAL "mshpc") + set(MKL_MPI msmpi) + endif() + check_required_vars(MKL_MPI) +endif() + +# Provides a list of IMPORTED targets for the project +if(NOT DEFINED MKL_IMPORTED_TARGETS) + set(MKL_IMPORTED_TARGETS "") +endif() + +# Clear temporary variables +set(MKL_C_COPT "") +set(MKL_F_COPT "") +set(MKL_SDL_COPT "") +set(MKL_CXX_COPT "") +set(MKL_SYCL_COPT "") +set(MKL_SYCL_LOPT "") +set(MKL_OFFLOAD_COPT "") +set(MKL_OFFLOAD_LOPT "") + +set(MKL_SUPP_LINK "") # Other link options. Usually at the end of the link-line. +set(MKL_SYCL_SUPP_LINK "") +set(MKL_LINK_LINE "") +set(MKL_SYCL_LINK_LINE "") +set(MKL_ENV_PATH "") # Temporary variable to work with PATH +set(MKL_ENV "") # Exported environment variables + +# Modify PATH variable to make it CMake-friendly +set(OLD_PATH $ENV{PATH}) +string(REPLACE ";" "\;" OLD_PATH "${OLD_PATH}") + +# Compiler options +if(GNU_C_COMPILER OR GNU_Fortran_COMPILER) + if(MKL_ARCH STREQUAL "ia32") + list(APPEND MKL_C_COPT -m32) + list(APPEND MKL_CXX_COPT -m32) + list(APPEND MKL_F_COPT -m32) + else() + list(APPEND MKL_C_COPT -m64) + list(APPEND MKL_CXX_COPT -m64) + list(APPEND MKL_F_COPT -m64) + endif() +endif() + +# Additonal compiler & linker options +if(SYCL_COMPILER) + list(APPEND MKL_SYCL_COPT "-fsycl") + list(APPEND MKL_SYCL_LOPT "-fsycl") + if(MKL_SYCL_LINK STREQUAL "static") + list(APPEND MKL_SYCL_LOPT "-fsycl-device-code-split=per_kernel") + endif() +endif() +if(ENABLE_OMP_OFFLOAD) + if(MKL_LINK STREQUAL "static") + list(APPEND MKL_OFFLOAD_LOPT "-fsycl-device-code-split=per_kernel") + endif() +endif() + +# For OpenMP Offload +if(ENABLE_OMP_OFFLOAD) + if(WIN32) + if(OPENMP_VERSION VERSION_GREATER_EQUAL "5.1") + if("Fortran" IN_LIST CURR_LANGS) + list(APPEND MKL_OFFLOAD_COPT -Qiopenmp -Qopenmp-targets:spir64 -DONEMKL_USE_OPENMP_VERSION=202011) + else() + list(APPEND MKL_OFFLOAD_COPT -Qiopenmp -Qopenmp-targets:spir64 -Qopenmp-version:51 -DONEMKL_USE_OPENMP_VERSION=202011) + endif() + else() + list(APPEND MKL_OFFLOAD_COPT -Qiopenmp -Qopenmp-targets:spir64) + endif() + # -MD and -MDd are manually added here because offload functionality uses DPC++ runtime. + if(CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + list(APPEND MKL_OFFLOAD_COPT -MDd) + else() + list(APPEND MKL_OFFLOAD_COPT -MD) + endif() + list(APPEND MKL_OFFLOAD_LOPT -Qiopenmp -Qopenmp-targets:spir64 -fsycl) + set(SKIP_LIBPATH ON) + else() + if(OPENMP_VERSION VERSION_GREATER_EQUAL "5.1") + if("Fortran" IN_LIST CURR_LANGS) + list(APPEND MKL_OFFLOAD_COPT -fiopenmp -fopenmp-targets=spir64 -DONEMKL_USE_OPENMP_VERSION=202011) + else() + list(APPEND MKL_OFFLOAD_COPT -fiopenmp -fopenmp-targets=spir64 -fopenmp-version=51 -DONEMKL_USE_OPENMP_VERSION=202011) + endif() + else () + list(APPEND MKL_OFFLOAD_COPT -fiopenmp -fopenmp-targets=spir64) + endif() + list(APPEND MKL_OFFLOAD_LOPT -fiopenmp -fopenmp-targets=spir64 -fsycl) + if(APPLE) + list(APPEND MKL_SUPP_LINK -lc++) + else() + list(APPEND MKL_SUPP_LINK -lstdc++) + endif() + endif() +endif() + +# For selected Interface +if(SYCL_COMPILER) + list(INSERT MKL_SYCL_COPT 0 "-DMKL_ILP64") +endif() + +if(MKL_INTERFACE_FULL) + if(MKL_ARCH STREQUAL "ia32") + if(GNU_Fortran_COMPILER) + set(MKL_SDL_IFACE_ENV "GNU") + endif() + else() + if(GNU_Fortran_COMPILER) + set(MKL_SDL_IFACE_ENV "GNU,${MKL_INTERFACE}") + else() + set(MKL_SDL_IFACE_ENV "${MKL_INTERFACE}") + endif() + if(MKL_INTERFACE STREQUAL "ilp64") + if("Fortran" IN_LIST CURR_LANGS) + if(INTEL_COMPILER) + if(WIN32) + list(APPEND MKL_F_COPT "-4I8") + else() + list(APPEND MKL_F_COPT "-i8") + endif() + elseif(GNU_Fortran_COMPILER) + list(APPEND MKL_F_COPT "-fdefault-integer-8") + elseif(PGI_COMPILER) + list(APPEND MKL_F_COPT "-i8") + endif() + endif() + list(INSERT MKL_C_COPT 0 "-DMKL_ILP64") + list(INSERT MKL_SDL_COPT 0 "-DMKL_ILP64") + list(INSERT MKL_CXX_COPT 0 "-DMKL_ILP64") + list(INSERT MKL_OFFLOAD_COPT 0 "-DMKL_ILP64") + else() + # lp64 + endif() + endif() + if(MKL_SDL_IFACE_ENV) + string(TOUPPER ${MKL_SDL_IFACE_ENV} MKL_SDL_IFACE_ENV) + endif() +endif() # MKL_INTERFACE_FULL + +# All oneMKL Libraries +if(SYCL_COMPILER) + set(MKL_SYCL_IFACE_LIB mkl_${MKL_SYCL_INTERFACE_FULL}) + if(WIN32 AND CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo" AND MKL_SYCL_THREADING STREQUAL "tbb_thread") + set(MKL_SYCL_THREAD mkl_tbb_threadd) + else() + set(MKL_SYCL_THREAD mkl_${MKL_SYCL_THREADING}) + endif() +endif() +set(MKL_SYCL) +set(MKL_SYCL_LIBS) +list(APPEND MKL_SYCL_LIBS mkl_sycl_blas) +list(APPEND MKL_SYCL_LIBS mkl_sycl_lapack) +list(APPEND MKL_SYCL_LIBS mkl_sycl_dft) +list(APPEND MKL_SYCL_LIBS mkl_sycl_sparse) +list(APPEND MKL_SYCL_LIBS mkl_sycl_data_fitting) +list(APPEND MKL_SYCL_LIBS mkl_sycl_rng) +list(APPEND MKL_SYCL_LIBS mkl_sycl_stats) +list(APPEND MKL_SYCL_LIBS mkl_sycl_vm) +if(NOT MKL_LINK STREQUAL "static") + if(WIN32 AND CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + list(TRANSFORM MKL_SYCL_LIBS APPEND "d") + endif() + list(APPEND MKL_SYCL ${MKL_SYCL_LIBS}) + # List for tracking incomplete onemKL package + set(MISSED_MKL_SYCL_LIBS) +else() + if(WIN32 AND CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + set(MKL_SYCL mkl_sycld) + else() + set(MKL_SYCL mkl_sycl) + endif() +endif() + +set(MKL_IFACE_LIB mkl_${MKL_INTERFACE_FULL}) +set(MKL_CORE mkl_core) +if(WIN32 AND CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo" AND MKL_THREADING STREQUAL "tbb_thread") + set(MKL_THREAD mkl_tbb_threadd) +else() + set(MKL_THREAD mkl_${MKL_THREADING}) +endif() +set(MKL_SDL mkl_rt) +if(MKL_ARCH STREQUAL "ia32") + set(MKL_BLAS95 mkl_blas95) + set(MKL_LAPACK95 mkl_lapack95) +else() + set(MKL_BLAS95 mkl_blas95_${MKL_INTERFACE}) + set(MKL_LAPACK95 mkl_lapack95_${MKL_INTERFACE}) +endif() +# BLACS +set(MKL_BLACS mkl_blacs_${MKL_MPI}_${MKL_INTERFACE}) +if(UNIX AND NOT APPLE AND MKL_MPI MATCHES "mpich") + # MPICH is compatible with INTELMPI Wrappers on Linux + set(MKL_BLACS mkl_blacs_intelmpi_${MKL_INTERFACE}) +endif() +if(WIN32) + if(MKL_MPI STREQUAL "msmpi") + if("Fortran" IN_LIST CURR_LANGS) + list(APPEND MKL_SUPP_LINK "msmpifec.lib") + endif() + # MSMPI and MSHPC are supported with the same BLACS library + set(MKL_BLACS mkl_blacs_msmpi_${MKL_INTERFACE}) + if(NOT MKL_LINK STREQUAL "static") + set(MKL_BLACS mkl_blacs_${MKL_INTERFACE}) + set(MKL_BLACS_ENV MSMPI) + endif() + elseif(MKL_MPI STREQUAL "intelmpi" AND NOT MKL_LINK STREQUAL "static") + set(MKL_BLACS mkl_blacs_${MKL_INTERFACE}) + set(MKL_BLACS_ENV INTELMPI) + endif() +endif() +# CDFT & SCALAPACK +set(MKL_CDFT mkl_cdft_core) +set(MKL_SCALAPACK mkl_scalapack_${MKL_INTERFACE}) + + +if(UNIX AND NOT APPLE) + if(MKL_LINK STREQUAL "static" OR MKL_SYCL_LINK STREQUAL "static") + set(START_GROUP "-Wl,--start-group") + set(END_GROUP "-Wl,--end-group") + if(SYCL_COMPILER) + set(SYCL_EXPORT_DYNAMIC "-Wl,-export-dynamic") + endif() + if(ENABLE_OMP_OFFLOAD) + set(EXPORT_DYNAMIC "-Wl,-export-dynamic") + endif() + endif() + if(MKL_LINK STREQUAL "dynamic") + set(MKL_RPATH "-Wl,-rpath=$") + if((GNU_Fortran_COMPILER OR PGI_COMPILER) AND "Fortran" IN_LIST CURR_LANGS) + set(NO_AS_NEEDED -Wl,--no-as-needed) + endif() + endif() + if(MKL_SYCL_LINK STREQUAL "dynamic") + set(MKL_SYCL_RPATH "-Wl,-rpath=$") + endif() + if(MKL_LINK STREQUAL "sdl") + set(MKL_RPATH "-Wl,-rpath=$") + endif() +endif() + +# Create a list of requested libraries, based on input options (MKL_LIBRARIES) +# Create full link-line in MKL_LINK_LINE +if(SYCL_COMPILER) + list(APPEND MKL_SYCL_LIBRARIES ${MKL_SYCL} ${MKL_SYCL_IFACE_LIB} ${MKL_SYCL_THREAD} ${MKL_CORE}) + list(TRANSFORM MKL_SYCL PREPEND MKL:: OUTPUT_VARIABLE MKL_SYCL_T) + list(APPEND MKL_SYCL_LINK_LINE ${MKL_SYCL_LOPT} ${SYCL_EXPORT_DYNAMIC} ${NO_AS_NEEDED} ${MKL_SYCL_RPATH} + ${MKL_SYCL_T} ${START_GROUP} MKL::${MKL_SYCL_IFACE_LIB} MKL::${MKL_SYCL_THREAD} MKL::${MKL_CORE} ${END_GROUP}) +endif() +list(APPEND MKL_LINK_LINE $,${MKL_OFFLOAD_LOPT},> + ${EXPORT_DYNAMIC} ${NO_AS_NEEDED} ${MKL_RPATH}) +if(ENABLE_BLAS95) + list(APPEND MKL_LIBRARIES ${MKL_BLAS95}) + list(APPEND MKL_LINK_LINE MKL::${MKL_BLAS95}) +endif() +if(ENABLE_LAPACK95) + list(APPEND MKL_LIBRARIES ${MKL_LAPACK95}) + list(APPEND MKL_LINK_LINE MKL::${MKL_LAPACK95}) +endif() +if(ENABLE_SCALAPACK) + list(APPEND MKL_LIBRARIES ${MKL_SCALAPACK}) + list(APPEND MKL_LINK_LINE MKL::${MKL_SCALAPACK}) +endif() +if(ENABLE_OMP_OFFLOAD AND NOT MKL_LINK STREQUAL "sdl") + list(APPEND MKL_LIBRARIES ${MKL_SYCL}) + list(TRANSFORM MKL_SYCL PREPEND MKL:: OUTPUT_VARIABLE MKL_SYCL_T) + list(APPEND MKL_LINK_LINE ${MKL_SYCL_T}) +endif() +list(APPEND MKL_LINK_LINE ${START_GROUP}) +if(ENABLE_CDFT) + list(APPEND MKL_LIBRARIES ${MKL_CDFT}) + list(APPEND MKL_LINK_LINE MKL::${MKL_CDFT}) +endif() +if(MKL_LINK STREQUAL "sdl") + list(APPEND MKL_LIBRARIES ${MKL_SDL}) + list(APPEND MKL_LINK_LINE MKL::${MKL_SDL}) +else() + list(APPEND MKL_LIBRARIES ${MKL_IFACE_LIB} ${MKL_THREAD} ${MKL_CORE}) + list(APPEND MKL_LINK_LINE MKL::${MKL_IFACE_LIB} MKL::${MKL_THREAD} MKL::${MKL_CORE}) +endif() +if(USE_MPI) + list(APPEND MKL_LIBRARIES ${MKL_BLACS}) + list(APPEND MKL_LINK_LINE MKL::${MKL_BLACS}) +endif() +list(APPEND MKL_LINK_LINE ${END_GROUP}) + +# Find all requested libraries +list(APPEND MKL_REQUESTED_LIBRARIES ${MKL_LIBRARIES}) +if(SYCL_COMPILER) + # If SYCL_COMPILER is still ON, MKL_SYCL_ARCH, MKL_SYCL_LINK, and MKL_SYCL_IFACE_LIB are the same as MKL_ARCH, MKL_LINK, and MKL_IFACE_LIB. + # Hence we can combine the libraries and find them in the following for loop. + # Note that MKL_SYCL_THREADING and MKL_THREADING could be different because of the default value. + list(APPEND MKL_REQUESTED_LIBRARIES ${MKL_SYCL_LIBRARIES}) + list(REMOVE_DUPLICATES MKL_REQUESTED_LIBRARIES) +endif() +foreach(lib ${MKL_REQUESTED_LIBRARIES}) + unset(${lib}_file CACHE) + if(MKL_LINK STREQUAL "static" AND NOT ${lib} STREQUAL ${MKL_SDL}) + find_library(${lib}_file ${LIB_PREFIX}${lib}${LIB_EXT} + PATHS ${MKL_ROOT} + PATH_SUFFIXES "lib${MKL_ARCH_DIR}" + NO_DEFAULT_PATH) + add_library(MKL::${lib} STATIC IMPORTED) + else() + find_library(${lib}_file NAMES ${LIB_PREFIX}${lib}${DLL_EXT} ${lib} + PATHS ${MKL_ROOT} + PATH_SUFFIXES "lib${MKL_ARCH_DIR}" + NO_DEFAULT_PATH) + add_library(MKL::${lib} SHARED IMPORTED) + endif() + if(NOT MKL_LINK STREQUAL "static" AND ${lib} MATCHES "mkl_sycl" AND ${${lib}_file} STREQUAL "${lib}_file-NOTFOUND") + list(APPEND MISSED_MKL_SYCL_LIBS ${lib}) + set(MKL_SYCL_DOMAIN "") + string(REGEX REPLACE "mkl_sycl_" "" MKL_SYCL_DOMAIN ${lib}) + if(WIN32 AND CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + string(REGEX REPLACE "d$" "" MKL_SYCL_DOMAIN ${MKL_SYCL_DOMAIN}) + endif() + string(TOUPPER ${MKL_SYCL_DOMAIN} MKL_SYCL_DOMAIN) + mkl_message(WARNING "Could NOT find MKL ${lib} for target MKL::MKL_SYCL::${MKL_SYCL_DOMAIN}") + else() + check_required_vars(${lib}_file) + mkl_message(STATUS "Found ${${lib}_file}") + endif() + # CMP0111, implemented in CMake 3.20+ requires a shared library target on Windows + # to be defined with IMPLIB and LOCATION property. + # It also requires a static library target to be defined with LOCATION property. + # Setting the policy to OLD usage, using cmake_policy() does not work as of 3.20.0, hence the if-else below. + if(WIN32 AND NOT MKL_LINK STREQUAL "static") + set_target_properties(MKL::${lib} PROPERTIES IMPORTED_IMPLIB "${${lib}_file}") + # Find corresponding DLL + set(MKL_DLL_GLOB ${lib}.*.dll) + file(GLOB MKL_DLL_FILE "${MKL_ROOT}/bin${MKL_ARCH_DIR}/${MKL_DLL_GLOB}" + # Legacy oneAPI layout support below + "${MKL_ROOT}/redist/${MKL_ARCH}/${MKL_DLL_GLOB}" + "${MKL_ROOT}/../redist/${MKL_ARCH}/${MKL_DLL_GLOB}" + "${MKL_ROOT}/../redist/${MKL_ARCH}/mkl/${MKL_DLL_GLOB}" + # Support for Conda directory layout + "${MKL_ROOT}/bin/${MKL_DLL_GLOB}" + ) + if(NOT ${lib} STREQUAL ${MKL_IFACE_LIB} AND NOT ${lib} STREQUAL ${MKL_BLAS95} AND NOT ${lib} STREQUAL ${MKL_LAPACK95}) # Windows IFACE libs are static only + list(LENGTH MKL_DLL_FILE MKL_DLL_FILE_LEN) + if(MKL_DLL_FILE_LEN) + # in case multiple versions of the same dll are found, select the highest version + list(SORT MKL_DLL_FILE) + list(REVERSE MKL_DLL_FILE) + list(GET MKL_DLL_FILE 0 MKL_DLL_FILE) + + mkl_message(STATUS "Found DLL: ${MKL_DLL_FILE}") + set_target_properties(MKL::${lib} PROPERTIES IMPORTED_LOCATION "${MKL_DLL_FILE}") + else() + if(${lib} MATCHES "mkl_sycl" AND ${${lib}_file} STREQUAL "${lib}_file-NOTFOUND") + mkl_message(WARNING "Could NOT find ${MKL_DLL_GLOB} for target MKL::MKL_SYCL::${MKL_SYCL_DOMAIN}") + else() + mkl_message(FATAL_ERROR "${MKL_DLL_GLOB} not found") + endif() + endif() + endif() + else() + set_target_properties(MKL::${lib} PROPERTIES IMPORTED_LOCATION "${${lib}_file}") + endif() + list(APPEND MKL_IMPORTED_TARGETS MKL::${lib}) +endforeach() + +# Threading selection +if(MKL_THREADING STREQUAL "tbb_thread" OR MKL_SYCL_THREADING STREQUAL "tbb_thread") + find_package(TBB CONFIG COMPONENTS tbb) + if(TARGET TBB::tbb) + if(MKL_THREADING STREQUAL "tbb_thread") + set(MKL_THREAD_LIB $) + set(MKL_SDL_THREAD_ENV "TBB") + endif() + if(MKL_SYCL_THREADING STREQUAL "tbb_thread") + set(MKL_SYCL_THREAD_LIB $) + endif() + get_property(TBB_LIB TARGET TBB::tbb PROPERTY IMPORTED_LOCATION_RELEASE) + get_filename_component(TBB_LIB_DIR ${TBB_LIB} DIRECTORY) + else() + if(UNIX) + set(TBB_LIBNAME libtbb.so) + else() + set(TBB_LIBNAME tbb.lib) + endif() + find_path(TBB_LIB_DIR ${TBB_LIBNAME} + HINTS $ENV{TBBROOT} $ENV{MKLROOT} ${MKL_ROOT} ${TBB_ROOT} + PATH_SUFFIXES "lib" "lib/intel64/gcc4.4" "lib/intel64/gcc4.8" + "../tbb/lib/intel64/gcc4.4" "../tbb/lib/intel64/gcc4.8" + "../../tbb/latest/lib/intel64/gcc4.8" + "../tbb/lib/intel64/vc14" "lib/intel64/vc14" + ) + find_library(TBB_LIBRARIES NAMES tbb + HINTS $ENV{TBBROOT} $ENV{MKLROOT} ${MKL_ROOT} ${TBB_ROOT} + PATH_SUFFIXES "lib" "lib/intel64/gcc4.4" "lib/intel64/gcc4.8" + "../tbb/lib/intel64/gcc4.4" "../tbb/lib/intel64/gcc4.8" + "../../tbb/latest/lib/intel64/gcc4.8" + "../tbb/lib/intel64/vc14" "lib/intel64/vc14" + ) + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(MKL REQUIRED_VARS TBB_LIBRARIES) + endif() + if(UNIX) + if(CMAKE_SKIP_BUILD_RPATH) + set(TBB_LINK "-L${TBB_LIB_DIR} -ltbb") + else() + set(TBB_LINK "-Wl,-rpath,${TBB_LIB_DIR} -L${TBB_LIB_DIR} -ltbb") + endif() + if(MKL_THREADING STREQUAL "tbb_thread") + list(APPEND MKL_SUPP_LINK ${TBB_LINK}) + if(APPLE) + list(APPEND MKL_SUPP_LINK -lc++) + else() + list(APPEND MKL_SUPP_LINK -lstdc++) + endif() + endif() + if(MKL_SYCL_THREADING STREQUAL "tbb_thread") + list(APPEND MKL_SYCL_SUPP_LINK ${TBB_LINK}) + endif() + endif() + if(WIN32 OR APPLE) + set(MKL_ENV_PATH ${TBB_LIB_DIR}) + endif() +endif() +if(NOT MKL_THREADING STREQUAL "tbb_thread" AND MKL_THREADING MATCHES "_thread") + if(MKL_THREADING STREQUAL "pgi_thread") + list(APPEND MKL_SUPP_LINK -mp -pgf90libs) + set(MKL_SDL_THREAD_ENV "PGI") + elseif(MKL_THREADING STREQUAL "gnu_thread") + list(APPEND MKL_SUPP_LINK -lgomp) + set(MKL_SDL_THREAD_ENV "GNU") + else() + # intel_thread + if(UNIX) + set(MKL_OMP_LIB iomp5) + set(LIB_EXT ".so") + if(APPLE) + set(LIB_EXT ".dylib") + endif() + else() + set(MKL_OMP_LIB libiomp5md) + endif() + set(MKL_SDL_THREAD_ENV "INTEL") + set(OMP_LIBNAME ${LIB_PREFIX}${MKL_OMP_LIB}${LIB_EXT}) + + find_library(OMP_LIBRARY ${OMP_LIBNAME} + HINTS $ENV{LIB} $ENV{LIBRARY_PATH} $ENV{MKLROOT} ${MKL_ROOT} $ENV{CMPLR_ROOT} + PATH_SUFFIXES "lib" "lib/${MKL_ARCH}" + "lib/${MKL_ARCH}_lin" "lib/${MKL_ARCH}_win" + "linux/compiler/lib/${MKL_ARCH}" + "linux/compiler/lib/${MKL_ARCH}_lin" + "windows/compiler/lib/${MKL_ARCH}" + "windows/compiler/lib/${MKL_ARCH}_win" + "../compiler/lib/${MKL_ARCH}_lin" "../compiler/lib/${MKL_ARCH}_win" + "../compiler/lib/${MKL_ARCH}" "../compiler/lib" "compiler/lib" + "../../compiler/latest/linux/compiler/lib/${MKL_ARCH}" + "../../compiler/latest/linux/compiler/lib/${MKL_ARCH}_lin" + "../../compiler/latest/windows/compiler/lib/${MKL_ARCH}" + "../../compiler/latest/windows/compiler/lib/${MKL_ARCH}_win" + "../../compiler/latest/mac/compiler/lib" + NO_DEFAULT_PATH) + if(WIN32) + set(OMP_DLLNAME ${LIB_PREFIX}${MKL_OMP_LIB}.dll) + find_path(OMP_DLL_DIR ${OMP_DLLNAME} + HINTS $ENV{LIB} $ENV{LIBRARY_PATH} $ENV{MKLROOT} ${MKL_ROOT} $ENV{CMPLR_ROOT} + PATH_SUFFIXES "bin" + # Legacy layout support for oneMKL + "redist/${MKL_ARCH}" + "redist/${MKL_ARCH}_win" "redist/${MKL_ARCH}_win/compiler" + "../redist/${MKL_ARCH}/compiler" "../compiler/lib" + "../../compiler/latest/windows/redist/${MKL_ARCH}_win" + "../../compiler/latest/windows/redist/${MKL_ARCH}_win/compiler" + "../../compiler/latest/windows/compiler/redist/${MKL_ARCH}_win" + "../../compiler/latest/windows/compiler/redist/${MKL_ARCH}_win/compiler" + NO_DEFAULT_PATH) + check_required_vars(OMP_DLL_DIR) + set(MKL_ENV_PATH "${OMP_DLL_DIR}") + endif() + + if(WIN32 AND SKIP_LIBPATH) + # Only for Intel OpenMP Offload + set(OMP_LINK "libiomp5md.lib") + else() + set(OMP_LINK "${OMP_LIBRARY}") + if(CMAKE_C_COMPILER_ID STREQUAL "PGI" OR CMAKE_Fortran_COMPILER_ID STREQUAL "PGI") + # Disable PGI OpenMP runtime for correct work of Intel OpenMP runtime + list(APPEND MKL_SUPP_LINK -nomp) + endif() + endif() + check_required_vars(OMP_LIBRARY OMP_LINK) + mkl_message(STATUS "Found ${OMP_LIBRARY}") + if(MKL_SYCL_THREADING STREQUAL "intel_thread") + set(MKL_SYCL_THREAD_LIB ${OMP_LINK}) + endif() + set(MKL_THREAD_LIB ${OMP_LINK}) + endif() +elseif(MKL_THREADING STREQUAL "sequential") + # Sequential threading + set(MKL_SDL_THREAD_ENV "SEQUENTIAL") +endif() # MKL_THREADING + +if(UNIX) + if(SYCL_COMPILER) + list(APPEND MKL_SYCL_SUPP_LINK -lm -ldl -lpthread) + endif() + list(APPEND MKL_SUPP_LINK -lm -ldl -lpthread) +endif() + +if(SYCL_COMPILER OR ENABLE_OMP_OFFLOAD) + if(WIN32) + # Detect sycl library version + if(NOT DEFINED SYCL_LIB_VER_CACHE) + set(SYCL_LIB_VER "") + find_library(SYCL_LIB_DIR ${LIB_PREFIX}sycl${LIB_EXT} + HINTS $ENV{LIB} $ENV{CMPLR_ROOT} + PATH_SUFFIXES "windows/lib" "../lib${MKL_ARCH_DIR}") + if(NOT SYCL_LIB_DIR) + foreach(ver RANGE 6 99) + find_library(SYCL_LIB_DIR ${LIB_PREFIX}sycl${ver}${LIB_EXT} + HINTS $ENV{LIB} $ENV{CMPLR_ROOT} + PATH_SUFFIXES "windows/lib" "../lib${MKL_ARCH_DIR}") + if(SYCL_LIB_DIR) + set(SYCL_LIB_VER ${ver}) + break() + endif() + endforeach() + endif() + set(SYCL_LIB_VER_CACHE ${SYCL_LIB_VER} CACHE STRING "") + endif() + + if(SYCL_COMPILER) + if(CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + list(APPEND MKL_SYCL_SUPP_LINK ${LINK_PREFIX}sycl${SYCL_LIB_VER_CACHE}d${LINK_SUFFIX}) + else() + list(APPEND MKL_SYCL_SUPP_LINK ${LINK_PREFIX}sycl${SYCL_LIB_VER_CACHE}${LINK_SUFFIX}) + endif() + endif() + if(ENABLE_OMP_OFFLOAD) + if(CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + list(APPEND MKL_SUPP_LINK ${LINK_PREFIX}sycl${SYCL_LIB_VER_CACHE}d${LINK_SUFFIX}) + else() + list(APPEND MKL_SUPP_LINK ${LINK_PREFIX}sycl${SYCL_LIB_VER_CACHE}${LINK_SUFFIX}) + endif() + endif() + else() + if(SYCL_COMPILER) + list(APPEND MKL_SYCL_SUPP_LINK ${LINK_PREFIX}sycl${LINK_SUFFIX}) + endif() + if(ENABLE_OMP_OFFLOAD) + list(APPEND MKL_SUPP_LINK ${LINK_PREFIX}sycl${LINK_SUFFIX}) + endif() + endif() + if(SYCL_COMPILER) + list(APPEND MKL_SYCL_SUPP_LINK ${LINK_PREFIX}OpenCL${LINK_SUFFIX}) + endif() + if(ENABLE_OMP_OFFLOAD) + list(APPEND MKL_SUPP_LINK ${LINK_PREFIX}OpenCL${LINK_SUFFIX}) + endif() +endif() + +# Setup link types based on input options +set(LINK_TYPES "") + +if(SYCL_COMPILER OR ENABLE_OMP_OFFLOAD) +# Remove missed mkl_sycl libraries in case of incomplete oneMKL package + if(MISSED_MKL_SYCL_LIBS) + list(REMOVE_ITEM MKL_SYCL_LIBS ${MISSED_MKL_SYCL_LIBS}) + list(TRANSFORM MISSED_MKL_SYCL_LIBS PREPEND MKL:: OUTPUT_VARIABLE MISSED_MKL_SYCL_TARGETS) + list(REMOVE_ITEM MKL_SYCL_LINK_LINE ${MISSED_MKL_SYCL_TARGETS}) + list(REMOVE_ITEM MKL_LINK_LINE ${MISSED_MKL_SYCL_TARGETS}) + endif() +endif() + +if(SYCL_COMPILER) + if(NOT TARGET MKL::MKL_SYCL) + add_library(MKL::MKL_SYCL INTERFACE IMPORTED GLOBAL) + add_library(MKL::MKL_DPCPP ALIAS MKL::MKL_SYCL) + add_dependencies(MKL::MKL_SYCL MKL_SYCL_MESSAGE) + endif() + target_compile_options(MKL::MKL_SYCL INTERFACE $<$:${MKL_SYCL_COPT}>) + target_link_libraries(MKL::MKL_SYCL INTERFACE ${MKL_SYCL_LINK_LINE} ${MKL_SYCL_THREAD_LIB} ${MKL_SYCL_SUPP_LINK}) + list(APPEND LINK_TYPES MKL::MKL_SYCL) + foreach(lib ${MKL_SYCL_LIBS}) + set(MKL_SYCL_DOMAIN "") + string(REGEX REPLACE "mkl_sycl_" "" MKL_SYCL_DOMAIN ${lib}) + if(WIN32 AND CMAKE_BUILD_TYPE MATCHES "Debug|DebInfo") + string(REGEX REPLACE "d$" "" MKL_SYCL_DOMAIN ${MKL_SYCL_DOMAIN}) + endif() + string(TOUPPER ${MKL_SYCL_DOMAIN} MKL_SYCL_DOMAIN) + add_library(MKL::MKL_SYCL::${MKL_SYCL_DOMAIN} INTERFACE IMPORTED GLOBAL) + add_dependencies(MKL::MKL_SYCL::${MKL_SYCL_DOMAIN} MKL_SYCL_MESSAGE) + # Only dynamic link has domain specific libraries + # Domain specific targets still use mkl_sycl for static + # STREQUAL "${lib}_file-NOTFOUND" + if(MKL_LINK STREQUAL "static") + target_link_libraries(MKL::MKL_SYCL::${MKL_SYCL_DOMAIN} INTERFACE ${MKL_SYCL_LINK_LINE} ${MKL_SYCL_THREAD_LIB} ${MKL_SYCL_SUPP_LINK}) + else() + list(TRANSFORM MKL_SYCL_LINK_LINE REPLACE ".*mkl_sycl.*" "TBD") + list(REMOVE_DUPLICATES MKL_SYCL_LINK_LINE) + list(TRANSFORM MKL_SYCL_LINK_LINE REPLACE "TBD" "MKL::${lib}") + target_link_libraries(MKL::MKL_SYCL::${MKL_SYCL_DOMAIN} INTERFACE ${MKL_SYCL_LINK_LINE} ${MKL_SYCL_THREAD_LIB} ${MKL_SYCL_SUPP_LINK}) + endif() + list(APPEND LINK_TYPES MKL::MKL_SYCL::${MKL_SYCL_DOMAIN}) + endforeach(lib) # MKL_SYCL_LIBS +endif() +# Single target for all C, Fortran link-lines +if(NOT TARGET MKL::MKL) + add_library(MKL::MKL INTERFACE IMPORTED GLOBAL) +endif() +target_compile_options(MKL::MKL INTERFACE + $<$,C>:${MKL_C_COPT}> + $<$,Fortran>:${MKL_F_COPT}> + $<$,CXX>:${MKL_CXX_COPT}> + $,${MKL_OFFLOAD_COPT},>) +target_link_libraries(MKL::MKL INTERFACE ${MKL_LINK_LINE} ${MKL_THREAD_LIB} ${MKL_SUPP_LINK}) +list(APPEND LINK_TYPES MKL::MKL) + +foreach(link ${LINK_TYPES}) + # Set properties on all INTERFACE targets + target_include_directories(${link} BEFORE INTERFACE "${MKL_INCLUDE}") + list(APPEND MKL_IMPORTED_TARGETS ${link}) +endforeach(link) # LINK_TYPES +# oneMKL could be added to implicit directories when it's defined in CPATH +# In order to avoid dependency on CPATH, remove oneMKL from implicit directories +if(CMAKE_C_IMPLICIT_INCLUDE_DIRECTORIES) + list(REMOVE_ITEM CMAKE_C_IMPLICIT_INCLUDE_DIRECTORIES "${MKL_INCLUDE}") +endif() +if(CMAKE_CXX_IMPLICIT_INCLUDE_DIRECTORIES) + list(REMOVE_ITEM CMAKE_CXX_IMPLICIT_INCLUDE_DIRECTORIES "${MKL_INCLUDE}") +endif() + +if(MKL_LINK STREQUAL "sdl") + list(APPEND MKL_ENV "MKL_INTERFACE_LAYER=${MKL_SDL_IFACE_ENV}" "MKL_THREADING_LAYER=${MKL_SDL_THREAD_ENV}") +endif() +if(WIN32 AND NOT MKL_LINK STREQUAL "static") + list(APPEND MKL_ENV "MKL_BLACS_MPI=${MKL_BLACS_ENV}") +endif() + +# Add oneMKL dynamic libraries if RPATH is not defined on Unix +if(UNIX AND CMAKE_SKIP_BUILD_RPATH) + if(MKL_LINK STREQUAL "sdl") + set(MKL_LIB_DIR $) + else() + set(MKL_LIB_DIR $) + endif() + if(APPLE) + list(APPEND MKL_ENV "DYLD_LIBRARY_PATH=${MKL_LIB_DIR}\;$ENV{DYLD_LIBRARY_PATH}") + else() + list(APPEND MKL_ENV "LD_LIBRARY_PATH=${MKL_LIB_DIR}\;$ENV{LD_LIBRARY_PATH}") + endif() +endif() + +# Add oneMKL dynamic libraries to PATH on Windows +if(WIN32 AND NOT MKL_LINK STREQUAL "static") + get_filename_component(MKL_DLL_DIR ${MKL_DLL_FILE} DIRECTORY) + set(MKL_ENV_PATH "${MKL_DLL_DIR}\;${MKL_ENV_PATH}") +endif() + +if(MKL_ENV_PATH) + list(APPEND MKL_ENV "PATH=${MKL_ENV_PATH}\;${OLD_PATH}") + if(APPLE) + list(APPEND MKL_ENV "DYLD_LIBRARY_PATH=${MKL_ENV_PATH}\:${OLD_PATH}") + endif() +endif() + +unset(MKL_DLL_FILE) + +endif() # MKL_LIBRARIES diff --git a/cmake/mkl/MKLConfigVersion.cmake b/cmake/mkl/MKLConfigVersion.cmake new file mode 100755 index 000000000..996cd550f --- /dev/null +++ b/cmake/mkl/MKLConfigVersion.cmake @@ -0,0 +1,59 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(PACKAGE_VERSION "2023.2.0") + +if(PACKAGE_VERSION VERSION_LESS PACKAGE_FIND_VERSION) + set(PACKAGE_VERSION_COMPATIBLE FALSE) +else() + + if("2023.2.0" MATCHES "^([0-9]+)\\.") + set(CVF_VERSION_MAJOR "${CMAKE_MATCH_1}") + else() + set(CVF_VERSION_MAJOR "2024.0.0") + endif() + + if(PACKAGE_FIND_VERSION_MAJOR STREQUAL CVF_VERSION_MAJOR) + set(PACKAGE_VERSION_COMPATIBLE TRUE) + else() + set(PACKAGE_VERSION_COMPATIBLE FALSE) + endif() + + if(PACKAGE_FIND_VERSION STREQUAL PACKAGE_VERSION) + set(PACKAGE_VERSION_EXACT TRUE) + endif() +endif() + + + +if("FALSE") + return() +endif() + + +if("${CMAKE_SIZEOF_VOID_P}" STREQUAL "" OR "" STREQUAL "") + return() +endif() + + +if(NOT CMAKE_SIZEOF_VOID_P STREQUAL "") + math(EXPR installedBits " * 8") + set(PACKAGE_VERSION "${PACKAGE_VERSION} (${installedBits}bit)") + set(PACKAGE_VERSION_UNSUITABLE TRUE) +endif() diff --git a/cmake/oneMKLConfig.cmake b/cmake/oneMKLConfig.cmake index 2a169855c..5baf9024b 100644 --- a/cmake/oneMKLConfig.cmake +++ b/cmake/oneMKLConfig.cmake @@ -21,6 +21,10 @@ list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}) include(CMakeFindDependencyMacro) #find_dependency(MKL REQUIRED) -find_dependency(SYCL REQUIRED) +# try to search for SYCLConfig first to find compiler. If it's not present, use local FindCompiler.cmake +find_package(SYCL QUIET) +if(NOT ${SYCL_FOUND}) + find_package(Compiler REQUIRED) +endif() include("${CMAKE_CURRENT_LIST_DIR}/oneMKLTargets.cmake") diff --git a/conan/profiles/inteldpcpp_lnx b/conan/profiles/inteldpcpp_lnx deleted file mode 100644 index 59192f974..000000000 --- a/conan/profiles/inteldpcpp_lnx +++ /dev/null @@ -1,47 +0,0 @@ -#=============================================================================== -# Copyright 2020-2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# -# -# SPDX-License-Identifier: Apache-2.0 -#=============================================================================== - -COMPILER_PREFIX= - -[settings] -os=Linux -os_build=Linux -arch=x86_64 -arch_build=x86_64 -compiler=clang -compiler.version=11 -compiler.libcxx=libstdc++11 -build_type=Release - -[build_requires] -cmake/[>=3.15] -ninja/1.10.0 - -[env] -CONAN_CMAKE_GENERATOR="Ninja" -CONAN_SYSREQUIRES_MODE=enabled -CONAN_SYSREQUIRES_SUDO=True -CC=$COMPILER_PREFIX/bin/clang -CXX=$COMPILER_PREFIX/bin/dpcpp -CPATH=[$COMPILER_PREFIX/include] -CFLAGS=-fdiagnostics-color -fdiagnostics-show-template-tree -march=native -CXXFLAGS=-fdiagnostics-color -fdiagnostics-show-template-tree -march=native -LDFLAGS=-Wl,-rpath=$COMPILER_PREFIX/lib/ -LIBRARY_PATH=$COMPILER_PREFIX/../../../tbb/latest/lib/intel64/gcc4.8/ -LD_LIBRARY_PATH=[$COMPILER_PREFIX/lib:$COMPILER_PREFIX/lib/x64:$COMPILER_PREFIX/compiler/lib/intel64:$COMPILER_PREFIX/../../../tbb/latest/lib/intel64/gcc4.8/] diff --git a/conan/remotes.txt b/conan/remotes.txt deleted file mode 100644 index 569ddd4e2..000000000 --- a/conan/remotes.txt +++ /dev/null @@ -1,2 +0,0 @@ -conan-center https://api.bintray.com/conan/conan/conan-center True -conan-community https://api.bintray.com/conan/conan-community/conan True diff --git a/conan/settings.yml b/conan/settings.yml deleted file mode 100644 index aa8bdadbd..000000000 --- a/conan/settings.yml +++ /dev/null @@ -1,92 +0,0 @@ - -# Only for cross building, 'os_build/arch_build' is the system that runs Conan -os_build: [Windows, WindowsStore, Linux, Macos, FreeBSD, SunOS, AIX] -arch_build: [x86, x86_64, ppc32be, ppc32, ppc64le, ppc64, armv5el, armv5hf, armv6, armv7, armv7hf, armv7s, armv7k, armv8, armv8_32, armv8.3, sparc, sparcv9, mips, mips64, avr, s390, s390x, sh4le] - -# Only for building cross compilation tools, 'os_target/arch_target' is the system for -# which the tools generate code -os_target: [Windows, Linux, Macos, Android, iOS, watchOS, tvOS, FreeBSD, SunOS, AIX, Arduino, Neutrino] -arch_target: [x86, x86_64, ppc32be, ppc32, ppc64le, ppc64, armv5el, armv5hf, armv6, armv7, armv7hf, armv7s, armv7k, armv8, armv8_32, armv8.3, sparc, sparcv9, mips, mips64, avr, s390, s390x, asm.js, wasm, sh4le] - -# Rest of the settings are "host" settings: -# - For native building/cross building: Where the library/program will run. -# - For building cross compilation tools: Where the cross compiler will run. -os: - Windows: - subsystem: [None, cygwin, msys, msys2, wsl] - WindowsStore: - version: ["8.1", "10.0"] - WindowsCE: - platform: ANY - version: ["5.0", "6.0", "7.0", "8.0"] - Linux: - Macos: - version: [None, "10.6", "10.7", "10.8", "10.9", "10.10", "10.11", "10.12", "10.13", "10.14", "10.15"] - Android: - api_level: ANY - iOS: - version: ["7.0", "7.1", "8.0", "8.1", "8.2", "8.3", "9.0", "9.1", "9.2", "9.3", "10.0", "10.1", "10.2", "10.3", "11.0", "11.1", "11.2", "11.3", "11.4", "12.0", "12.1", "12.2", "12.3", "12.4", "13.0", "13.1"] - watchOS: - version: ["4.0", "4.1", "4.2", "4.3", "5.0", "5.1", "5.2", "5.3", "6.0", "6.1"] - tvOS: - version: ["11.0", "11.1", "11.2", "11.3", "11.4", "12.0", "12.1", "12.2", "12.3", "12.4", "13.0"] - FreeBSD: - SunOS: - AIX: - Arduino: - board: ANY - Emscripten: - Neutrino: - version: ["6.4", "6.5", "6.6", "7.0"] -arch: [x86, x86_64, ppc32be, ppc32, ppc64le, ppc64, armv4, armv4i, armv5el, armv5hf, armv6, armv7, armv7hf, armv7s, armv7k, armv8, armv8_32, armv8.3, sparc, sparcv9, mips, mips64, avr, s390, s390x, asm.js, wasm, sh4le] -compiler: - sun-cc: - version: ["5.10", "5.11", "5.12", "5.13", "5.14"] - threads: [None, posix] - libcxx: [libCstd, libstdcxx, libstlport, libstdc++] - gcc: &gcc - version: ["4.1", "4.4", "4.5", "4.6", "4.7", "4.8", "4.9", - "5", "5.1", "5.2", "5.3", "5.4", "5.5", - "6", "6.1", "6.2", "6.3", "6.4", - "7", "7.1", "7.2", "7.3", "7.4", - "8", "8.1", "8.2", "8.3", - "9", "9.1", "9.2"] - libcxx: [libstdc++, libstdc++11] - threads: [None, posix, win32] # Windows MinGW - exception: [None, dwarf2, sjlj, seh] # Windows MinGW - cppstd: [None, 98, gnu98, 11, gnu11, 14, gnu14, 17, gnu17, 20, gnu20] - Visual Studio: &visual_studio - runtime: [MD, MT, MTd, MDd] - version: ["8", "9", "10", "11", "12", "14", "15", "16"] - toolset: [None, v90, v100, v110, v110_xp, v120, v120_xp, - v140, v140_xp, v140_clang_c2, LLVM-vs2012, LLVM-vs2012_xp, - LLVM-vs2013, LLVM-vs2013_xp, LLVM-vs2014, LLVM-vs2014_xp, - LLVM-vs2017, LLVM-vs2017_xp, v141, v141_xp, v141_clang_c2, v142] - cppstd: [None, 14, 17, 20] - clang: - version: ["3.3", "3.4", "3.5", "3.6", "3.7", "3.8", "3.9", "4.0", - "5.0", "6.0", "7.0", "7.1", - "8", "9", "10", "11"] - libcxx: [libstdc++, libstdc++11, libc++, c++_shared, c++_static] - cppstd: [None, 98, gnu98, 11, gnu11, 14, gnu14, 17, gnu17, 20, gnu20] - apple-clang: - version: ["5.0", "5.1", "6.0", "6.1", "7.0", "7.3", "8.0", "8.1", "9.0", "9.1", "10.0", "11.0"] - libcxx: [libstdc++, libc++] - cppstd: [None, 98, gnu98, 11, gnu11, 14, gnu14, 17, gnu17, 20, gnu20] - intel: - version: ["11", "12", "13", "14", "15", "16", "17", "18", "19"] - base: - gcc: - <<: *gcc - threads: [None] - exception: [None] - Visual Studio: - <<: *visual_studio - qcc: - version: ["4.4", "5.4"] - libcxx: [cxx, gpp, cpp, cpp-ne, accp, acpp-ne, ecpp, ecpp-ne] - -build_type: [None, Debug, Release, RelWithDebInfo, MinSizeRel] - - -cppstd: [None, 98, gnu98, 11, gnu11, 14, gnu14, 17, gnu17, 20, gnu20] # Deprecated, use compiler.cppstd diff --git a/conanfile.py b/conanfile.py deleted file mode 100644 index 930cedc51..000000000 --- a/conanfile.py +++ /dev/null @@ -1,156 +0,0 @@ -#=============================================================================== -# Copyright 2020-2021 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions -# and limitations under the License. -# -# -# SPDX-License-Identifier: Apache-2.0 -#=============================================================================== - -from conans import ConanFile, CMake, tools -from packaging.version import parse -from six import StringIO - -class oneMKLConan(ConanFile): - name = "oneMKL" - version = "0.1.0-beta" - url = "https://github.com/oneapi-src/oneMKL" - description = "oneMKL interfaces is an open-source implementation of oneMKL Data Parallel C++ (DPC++) interfaces according to oneMKL specification that can work with multiple devices (backends) using device specific libraries underneath." - license = "Apache License Version 2.0" - legal_notice = "\ -LEGAL NOTICE: By downloading and using this container or script as applicable \n \ -(the “Software Package”), and any included software or software made available \n \ -for download, you agree to the terms and conditions of the software license \n \ -agreements for the Software Package, which may also include notices, \n \ -disclaimers or license terms for third party software (the “Agreements”) \n \ -located at https://github.com/conan-io/conan/blob/develop/LICENSE.md, the \n \ -THIRD-PARTY-PROGRAMS file and in the README.md file included with the Software Package. \n \ - " - - # Dependencies - netlib_version = "3.7.1" - sphinx_version = "2.4.4" - - settings = "os", "compiler", "build_type", "arch" - options = { - # Build style - "build_shared_libs": [True, False], - "target_domains" : "ANY", - - # Backends - "enable_mklcpu_backend" : [True, False], - "enable_mklgpu_backend" : [True, False], - - # Threading for mklcpu_backend - "enable_mklcpu_thread_tbb": [True, False], - - # Testing - "build_functional_tests" : [True, False], - - # Documentation - "build_doc" : [True, False] - } - default_options = { - "build_shared_libs" : True, - "target_domains" : None, - - "enable_mklcpu_backend" : True, - "enable_mklgpu_backend" : True, - - "enable_mklcpu_thread_tbb": True, - - "build_functional_tests" : True, - - "build_doc" : False, - - # External package options - "lapack:shared": True - } - generators = "cmake" - no_copy_source = True - exports_sources = "cmake/*", "include/*", "tests/*", "CMakeLists.txt" - - - def system_requirements(self): - self.output.info(f"\n {self.legal_notice}") - self.global_system_requirements = True - installer = tools.SystemPackageTool() - if self.options.enable_mklcpu_backend or self.options.enable_mklgpu_backend: - installer.add_repository("\"deb https://apt.repos.intel.com/oneapi all main\"") - installer.install("intel-oneapi-mkl-devel") # User must apt-key add GPG key before they can download oneMKL - if self.options.enable_mklcpu_thread_tbb: - installer.install("intel-oneapi-tbb-devel") # For libtbb.so used during link-time - - - def get_python_exe(self): - # Find supported Python binary - python_exe = "python3" - try: - self.run(f"{python_exe} --version", output=False) - except: - python_exe = "python" - my_buffer = StringIO() - self.run(f"{python_exe} --version", output=my_buffer) - ver_found = parse( my_buffer.getvalue().replace('Python ', '') ) - if ver_found < parse('3.6.0'): - self.output.error(f"Python 3.6.0 or higher required. Found {ver_found}") - return - return python_exe - - - def build_requirements(self): - if self.options.build_functional_tests: - self.build_requires(f"lapack/{self.netlib_version}@conan/stable") - # For Sphinx only - if self.options.build_doc: - # Use pip to install Sphinx as a user package - self.run(f"{self.get_python_exe()} -m pip install sphinx=={self.sphinx_version}") - - - def _cmake(self): - cmake = CMake(self) - return cmake - - - def build(self): - cmake = self._cmake() - cmake.definitions.update({ - # Options - "BUILD_SHARED_LIBS" : self.options.build_shared_libs, - "ENABLE_MKLCPU_BACKEND" : self.options.enable_mklcpu_backend, - "ENABLE_MKLGPU_BACKEND" : self.options.enable_mklgpu_backend, - "ENABLE_MKLCPU_THREAD_TBB" : self.options.enable_mklcpu_thread_tbb, - "BUILD_FUNCTIONAL_TESTS" : self.options.build_functional_tests, - "BUILD_DOC" : self.options.build_doc, - - # Paramaters - # Conan does not officially support oneAPI DPC++ Compiler, hence disable the compiler_id check - "CONAN_DISABLE_CHECK_COMPILER" : True, - "MKL_ROOT" : "/opt/intel/inteloneapi/mkl/latest", - }) - # Pass target_domains definition transparently to CMake, since CMakeLists.txt contains checks and defaults. - if self.options.target_domains != None: - cmake.definitions["TARGET_DOMAINS"] = self.options.target_domains - cmake.configure() - cmake.build() - if tools.get_env("CONAN_RUN_TESTS", default=True) and self.options.build_functional_tests: - cmake.test() - - - def package(self): - cmake = self._cmake() - cmake.install() - - - def package_info(self): - self.cpp_info.libs = ["onemkl"] diff --git a/docs/README.md b/docs/README.md index 32a56f82c..040d22aac 100644 --- a/docs/README.md +++ b/docs/README.md @@ -3,9 +3,8 @@ This folder contains oneMKL documentation in reStructuredText (rST) format. The documentation build step is skipped by default. -To enable building documentation from the main build: -- Set `-o build_doc=True` when building with Conan. For more information see [Building with Conan](../README.md#building-with-conan) -- Set `-DBUILD_DOC=ON` when building with CMake. For more information see [Building with CMake](../README.md#building-with-cmake) +To enable building documentation from the main build, set `-DBUILD_DOC=ON`. +For more information see [Building with CMake](../README.md#building-with-cmake). To build documentation only, use the following commands from the current folder: ```bash diff --git a/docs/building_and_running_tests.rst b/docs/building_and_running_tests.rst new file mode 100644 index 000000000..43d3431af --- /dev/null +++ b/docs/building_and_running_tests.rst @@ -0,0 +1,51 @@ +.. _building_and_running_tests: + +Building and Running Tests +========================== + +The functional tests are enabled by default, and can be enabled/disabled +with the CMake build parameter ``-DBUILD_FUNCTIONAL_TESTS=True/False``. Only +the tests relevant to the enabled backends and target domains will be built. + +Building tests for BLAS and LAPACK domains requires additional libraries for +reference. + +* BLAS: Requires a reference BLAS library. +* LAPACK: Requires a reference LAPACK library. + +For both BLAS and LAPACK, shared libraries supporting both 32 and 64 bit +indexing are required. + +A reference LAPACK implementation (including BLAS) can be built as the +following: + +.. code-block:: bash + + git clone https://github.com/Reference-LAPACK/lapack.git + cd lapack; mkdir -p build; cd build + cmake -DCMAKE_INSTALL_PREFIX=~/lapack -DCBLAS=True -DLAPACK=True -DLAPACKE=True -DBUILD_INDEX64=True -DBUILD_SHARED_LIBS=True .. + cmake --build . -j --target install + cmake -DCMAKE_INSTALL_PREFIX=~/lapack -DCBLAS=True -DLAPACK=True -DLAPACKE=True -DBUILD_INDEX64=False -DBUILD_SHARED_LIBS=True .. + cmake --build . -j --target install + +and then used in oneMKL by setting ``-REF_BLAS_ROOT=/path/to/lapack/install`` +and ``-DREF_LAPACK_ROOT=/path/to/lapack/install``. + +You can re-run tests without re-building the entire project. + +To run the tests, either run test binaries individually, or use ``ctest`` CMake test driver program. + +.. code-block:: bash + + # Run all tests + ctest + # Run only Gpu specific tests + ctest -R Gpu + # Exclude Cpu tests + ctest -E Cpu + +For more ``ctest`` options, refer to `ctest manual page `_. + +When running tests you may encounter the issue ``BACKEND NOT FOUND EXCEPTION``, +you may need to add your ``/lib`` to your +``LD_LIBRARY_PATH`` on Linux. diff --git a/docs/building_the_project.rst b/docs/building_the_project.rst deleted file mode 100644 index 668d3506a..000000000 --- a/docs/building_the_project.rst +++ /dev/null @@ -1,610 +0,0 @@ -.. _building_the_project: - -Building the Project -==================== - -.. _build_setup: - -Build Setup -########### - -#. - Install Intel(R) oneAPI DPC++ Compiler (select the variant as described in - :ref:`Selecting a Compiler`). - -#. - Clone this project to ````\ , where ```` - is the root directory of this repository. - -#. - You can :ref:`Build with Conan ` to automate the - process of getting dependencies or you can download and install the - required dependencies manually and - :ref:`Build with CMake ` directly. - -.. note:: - Conan package manager automates the process of getting required packages - so that you do not have to go to different web location and follow different - instructions to install them. - -.. _build_setup_with_hipsycl: - -Build Setup with hipSYCL -######################## - -#. - Make sure that the dependencies of hipSYCL are fulfilled. For a detailed - description, see the - `hipSYCL installation readme `_. - -#. - Install hipSYCL with the prefered backends enabled. hipSYCL supports - various backends. You can customize support for the target system at - compile time by setting the appropriate configuration flags; see the - `hipSYCL documentation `_ - for instructions. - -#. - Install `AMD rocBLAS `_. - -#. - Clone this project to ````, where ```` is - the root directory of this repository. - -#. - Download and install the required dependencies manually and - :ref:`Build with CMake `. - -.. _building_with_conan: - -Building with Conan -################### - -** This method currently works on Linux* only ** - -** Make sure you have completed :ref:`Build Setup `. ** - -.. note:: - To understand how dependencies are resolved, refer to "Product and Version - Information" under - `Support and Requirements `_. - For details about Conan package manager, refer to the - `Conan Documentation `_. - -Getting Conan -^^^^^^^^^^^^^ - -Conan can be `installed `_ from pip: - -.. code-block:: bash - - pip3 install conan - -Setting up Conan -^^^^^^^^^^^^^^^^ - -Conan Default Directory -~~~~~~~~~~~~~~~~~~~~~~~ - -Conan stores all files and data in ``~/.conan``. If you are fine with this -behavior, you can skip to the :ref:`Conan Profiles ` section. - -To change this behavior, set the environment variable ``CONAN_USER_HOME`` to a -path of your choice. A ``.conan/`` directory will be created in this path and -future Conan commands will use this directory to find configuration files and -download dependent packages. Packages will be downloaded into -``$CONAN_USER_HOME/data``. To change the ``"/data"`` part of this directory, -refer to the ``[storage]`` section of ``conan.conf`` file. - -To make this setting persistent across terminal sessions, you can add the -line below to your ``~/.bashrc`` or custom runscript. Refer to the -`Conan Documentation `_ -for more details. - -.. code-block:: sh - - export CONAN_USER_HOME=/usr/local/my_workspace/conan_cache - -.. _conan-profiles: - -Conan Profiles -~~~~~~~~~~~~~~ - -Profiles are a way for Conan to determine a basic environment to use for -building a project. This project ships with profiles for: - - -* Intel(R) oneAPI DPC++ Compiler for x86 CPU and Intel GPU backend: ``inteldpcpp_lnx`` - - -#. Open the profile you wish to use from ``/conan/profiles/`` - and set ``COMPILER_PREFIX`` to the path to the root folder of compiler. - The root folder is the one that contains the ``bin`` and ``lib`` - directories. For example, Intel(R) oneAPI DPC++ Compiler root folder for - default installation on Linux is - ``/opt/intel/inteloneapi/compiler//linux``. The user can define a - custom path for installing the compiler. - -.. code-block:: ini - - COMPILER_PREFIX= - - -#. - You can customize the ``[env]`` section of the profile based on individual - requirements. - -#. - Install configurations for this project: - - .. code-block:: sh - - # Inside - $ conan config install conan/ - - This command installs all contents of ``/conan/``\ , most - importantly profiles, to conan default directory. - -.. note:: - If you change the profile, you must re-run the above command before you can - use the new profile. - -Building -^^^^^^^^ - -#. - Out-of-source build - - .. code-block:: bash - - # Inside - mkdir build && cd build - -#. - If you choose to build backends with the Intel(R) oneAPI - Math Kernel Library, install the GPG key as mentioned here: - https://software.intel.com/en-us/articles/oneapi-repo-instructions#aptpkg - -#. - Install dependencies - - .. code-block:: sh - - conan install .. --profile --build missing [-o =] [-o =] - - The ``conan install`` command downloads and installs all requirements for - the oneMKL DPC++ Interfaces project as defined in - ``/conanfile.py`` based on the options passed. It also - creates ``conanbuildinfo.cmake`` file that contains information about all - dependencies and their directories. This file is used in top-level - ``CMakeLists.txt``. - -``-pr | --profile `` -Defines a profile for Conan to use for building the project. - -``-b | --build `` -Tells Conan to build or re-build a specific package. If ``missing`` is passed -as a value, all missing packages are built. This option is recommended when -you build the project for the first time, because it caches required packages. -You can skip this option for later use of this command. - - -#. Build Project - .. code-block:: sh - - conan build .. [--configure] [--build] [--test] # Default is all - -The ``conan build`` command executes the ``build()`` procedure from -``/conanfile.py``. Since this project uses ``CMake``\ , you -can choose to ``configure``\ , ``build``\ , ``test`` individually or perform -all steps by passing no optional arguments. - - -#. Optionally, you can also install the package. Similar to ``cmake --install . --prefix ``. - -.. code-block:: sh - - conan package .. --build-folder . --install-folder - -``-bf | --build-folder`` -Tells Conan where to find the built project. - -``-if | --install-folder`` -Tells Conan where to install the package. It is similar to specifying ``CMAKE_INSTALL_PREFIX`` - -.. note:: - For a detailed list of commands and options, refer to the - `Conan Command Reference `_. - -Conan Build Options -^^^^^^^^^^^^^^^^^^^ - -Backend-Related Options -~~~~~~~~~~~~~~~~~~~~~~~ - -The following ``options`` are available to pass on ``conan install`` when -building the oneMKL library: - - -* ``build_shared_libs=[True | False]``. Setting it to ``True`` enables the building of dynamic libraries. The default value is ``True``. -* ``target_domains=[]``. Setting it to ``blas`` or any other list of domain(s), enables building of those specific domain(s) only. If not defined, the default value is all supported domains. -* ``enable_mklcpu_backend=[True | False]``. Setting it to ``True`` enables the building of oneMKL mklcpu backend. The default value is ``True``. -* ``enable_mklgpu_backend=[True | False]``. Setting it to ``True`` enables the building of oneMKL mklgpu backend. The default value is ``True``. -* ``enable_mklcpu_thread_tbb=[True | False]``. Setting it to ``True`` enables oneMKL on CPU with TBB threading instead of sequential. The default value is ``True``. - -Testing-Related Options -~~~~~~~~~~~~~~~~~~~~~~~ - -* ``build_functional_tests=[True | False]``. Setting it to ``True`` enables - the building of functional tests. The default value is ``True``. - -Example-Related Options -~~~~~~~~~~~~~~~~~~~~~~~ - -* ``build_examples=[True | False]``. Setting it to ``True`` enables - the building of examples. The default value is ``True``. Compile_time_dispatching examples will always be built if this value is set to true. Run_time_dispatching examples will be build if both this value and ``build_shared_libs`` is set to true - -Documentation -~~~~~~~~~~~~~ - -* ``build_doc=[True | False]``. Setting it to ``True`` enables the building of rst files to generate HTML files for updated documentation. The default value is ``False``. - -.. note:: - For a mapping between Conan and CMake options, refer to - :ref:`Building with CMake `. - -Example -^^^^^^^ - -Build oneMKL as a static library for oneMKL cpu and gpu backend: -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -.. code-block:: sh - - # Inside - mkdir build && cd build - conan install .. --build missing --profile inteldpcpp_lnx -o build_shared_libs=False - conan build .. - -.. _building_with_cmake: - -Building with CMake -################### - -#. - Make sure you have completed `Build Setup <#build-setup>`_. - -#. - Build and install all required `dependencies <#software-requirements>`_. - -Building for oneMKL -^^^^^^^^^^^^^^^^^^^ - -* On Linux* - - .. code-block:: bash - - # Inside - mkdir build && cd build - cmake .. [-DCMAKE_CXX_COMPILER=/bin/dpcpp] # required only if dpcpp is not found in environment variable PATH - [-DCMAKE_C_COMPILER=/bin/icx] # required only if icx is not found in environment variable PATH - [-DMKL_ROOT=] # required only if environment variable MKLROOT is not set - [-DREF_BLAS_ROOT=] # required only for testing - [-DREF_LAPACK_ROOT=] # required only for testing - cmake --build . - ctest - cmake --install . --prefix - -* On Windows* - - .. code-block:: bash - - # Inside - md build && cd build - cmake .. -G Ninja [-DCMAKE_CXX_COMPILER=\bin\dpcpp] # required only if dpcpp is not found in environment variable PATH - [-DCMAKE_C_COMPILER=\bin\icx] # required only if icx is not found in environment variable PATH - [-DMKL_ROOT=] # required only if environment variable MKLROOT is not set - [-DREF_BLAS_ROOT=] # required only for testing - [-DREF_LAPACK_ROOT=] # required only for testing - ninja - ctest - cmake --install . --prefix - -Building for CUDA (with hipSYCL) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -* On Linux* - -With the cuBLAS backend: - -.. code-block:: bash - - # Inside - mkdir build && cd build - cmake .. -DENABLE_CUBLAS_BACKEND=True \ - -DENABLE_MKLGPU_BACKEND=False # Disable all backends except for cuBLAS - -DENABLE_MKLCPU_BACKEND=False \ - -DENABLE_NETLIB_BACKEND=False \ - -DENABLE_ROCBLAS_BACKEND=False \ - -DHIPSYCL_TARGETS=cuda:sm_75 \ # Specify the targeted device architectures - -DONEMKL_SYCL_IMPLEMENTATION=hipSYCL \ - [-DREF_BLAS_ROOT=] # required only for testing - cmake --build . - ctest - cmake --install . --prefix - -To build with the cuRAND backend instead simply replace: - -.. code-block:: bash - - -DENABLE_CUBLAS_BACKEND=True \ - -With: - -.. code-block:: bash - - -DENABLE_CURAND_BACKEND=True \ - - -Building for CUDA (with clang++) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -* On Linux* - -With the cuBLAS backend: - -.. code-block:: bash - - # Inside - mkdir build && cd build - cmake .. [-DCMAKE_CXX_COMPILER=/bin/clang++] # required only if clang++ is not found in environment variable PATH - [-DCMAKE_C_COMPILER=/bin/clang] # required only if clang is not found in environment variable PATH - -DENABLE_CUBLAS_BACKEND=True \ - -DENABLE_MKLCPU_BACKEND=False # disable Intel MKL CPU backend - -DENABLE_MKLGPU_BACKEND=False # disable Intel MKL GPU backend - [-DREF_BLAS_ROOT=] # required only for testing - cmake --build . - ctest - cmake --install . --prefix - -To build with the cuRAND backend instead simply replace: - -.. code-block:: bash - - -DENABLE_CUBLAS_BACKEND=True \ - -With: - -.. code-block:: bash - - -DENABLE_CURAND_BACKEND=True \ - - -Building for ROCm (with hipSYCL) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -With the AMD rocBLAS backend: - -* On Linux* - -.. code-block:: bash - - # Inside - mkdir build && cd build - cmake .. -DENABLE_CUBLAS_BACKEND=False \ - -DENABLE_MKLCPU_BACKEND=False/True # hipSYCL supports MKLCPU backend - -DENABLE_NETLIB_BACKEND=False/True # hipSYCL supports NETLIB backend - -DENABLE_MKLGPU_BACKEND=False # disable Intel MKL GPU backend - -DENABLE_ROCBLAS_BACKEND=True \ - -DTARGET_DOMAINS=blas # hipSYCL supports BLAS and RNG domains - -DHIPSYCL_TARGETS=omp\;hip:gfx906 # Specify the targetted device architectures - -DONEMKL_SYCL_IMPLEMENTATION=hipSYCL # Use the hipSYCL cmake integration - [-DREF_BLAS_ROOT=] # required only for testing - cmake --build . - ctest - cmake --install . --prefix - -To build with the rocRAND backend instead simply replace: - -.. code-block:: bash - - -DENABLE_ROCBLAS_BACKEND=True \ - -DTARGET_DOMAINS=blas - -With: - -.. code-block:: bash - - -DENABLE_ROCRAND_BACKEND=True \ - -DTARGET_DOMAINS=rng - -Building for ROCm (with clang++) -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -With the AMD rocBLAS backend: - - -* On Linux* - -.. code-block:: bash - - # Inside - mkdir build && cd build - cmake .. [-DCMAKE_CXX_COMPILER=/bin/clang++] # required only if clang++ is not found in environment variable PATH - [-DCMAKE_C_COMPILER=/bin/clang] # required only if clang is not found in environment variable PATH - -DENABLE_CUBLAS_BACKEND=False \ - -DENABLE_MKLCPU_BACKEND=False \ # disable Intel MKL CPU backend - -DENABLE_MKLGPU_BACKEND=False \ # disable Intel MKL GPU backend - -DENABLE_ROCBLAS_BACKEND=True \ - -DHIP_TARGETS=gfx90a \ # Specify the targetted device architectures - [-DREF_BLAS_ROOT=] # required only for testing - cmake --build . - export SYCL_DEVICE_FILTER=HIP - ctest - cmake --install . --prefix - -To build with the rocRAND backend instead simply replace: - -.. code-block:: bash - - -DENABLE_ROCBLAS_BACKEND=True \ - -DTARGET_DOMAINS=blas - -With: - -.. code-block:: bash - - -DENABLE_ROCRAND_BACKEND=True \ - -DTARGET_DOMAINS=rng - -To build with the rocSOLVER backend instead simply replace: - -.. code-block:: bash\ - - -DENABLE_ROCBLAS_BACKEND=True \ - -DTARGET_DOMAINS=blas -With: - -.. code-block:: bash - - -DENABLE_ROCSOLVER_BACKEND=True \ - -DTARGET_DOMAINS=lapack - -**AMD GPU device architectures** - -The device architecture can be retrieved via the ``rocminfo`` tool. The architecture will be displayed in the ``Name:`` row. - -A few often-used architectures are listed below: - -.. list-table:: - :header-rows: 1 - - * - Architecture - - AMD GPU name - * - gfx906 - - | AMD Radeon Instinct(TM) MI50/60 Accelerator - | AMD Radeon(TM) (Pro) VII Graphics Card - * - gfx908 - - AMD Instinct(TM) MI 100 Accelerator - * - gfx900 - - | Radeon Instinct(TM) MI 25 Accelerator - | Radeon(TM) RX Vega 64/56 Graphics - -Build Options -^^^^^^^^^^^^^ - -When building oneMKL the SYCL implementation can be determined, by setting the -``ONEMKL_SYCL_IMPLEMENTATION`` option. Possible values are: - -* ``dpc++`` (default) for the - `Intel(R) oneAPI DPC++ Compiler `_ - and for the ``clang++`` from - `Intel project for LLVM* technology `_ compilers. -* ``hipsycl`` for the `hipSYCL `_ SYCL implementation. - -All options specified in the Conan section are available to CMake. You can -specify these options using ``-D=``. - -The following table provides a detailed mapping of options between Conan and -CMake. - -.. list-table:: - :header-rows: 1 - - * - Conan Option - - CMake Option - - Supported Values - - Default Value - * - build_shared_libs - - BUILD_SHARED_LIBS - - True, False - - True - * - enable_mklcpu_backend - - ENABLE_MKLCPU_BACKEND - - True, False - - True - * - enable_mklgpu_backend - - ENABLE_MKLGPU_BACKEND - - True, False - - True - * - *Not Supported* - - ENABLE_CUBLAS_BACKEND - - True, False - - False - * - *Not Supported* - - ENABLE_CUSOLVER_BACKEND - - True, False - - False - * - *Not Supported* - - ENABLE_CURAND_BACKEND - - True, False - - False - * - *Not Supported* - - ENABLE_NETLIB_BACKEND - - True, False - - False - * - *Not Supported* - - ENABLE_ROCBLAS_BACKEND - - True, False - - False - * - enable_mklcpu_thread_tbb - - ENABLE_MKLCPU_THREAD_TBB - - True, False - - True - * - build_functional_tests - - BUILD_FUNCTIONAL_TESTS - - True, False - - True - * - build_examples - - BUILD_EXAMPLES - - True, False - - True - * - build_doc - - BUILD_DOC - - True, False - - False - * - target_domains (list) - - TARGET_DOMAINS (list) - - blas, lapack, rng, dft - - All domains - -.. note:: - ``build_functional_tests`` and related CMake options affect all domains at a - global scope. - -  -.. note:: - When building with hipSYCL, you must additionally provide - ``-DHIPSYCL_TARGETS`` according to the targeted hardware. For the options, - see the tables in the hipSYCL-specific sections. - - -.. note:: - When building with clang++ for AMD backends, you must additionally set - ``SYCL_DEVICE_FILTER`` to ``HIP`` and provide ``-DHIP_TARGETS`` according to - the targeted hardware. This backend has only been tested for the ``gfx90a`` - architecture (MI210) at the time of writing. - -.. note:: - When building with ``BUILD_FUNCTIONAL_TESTS=yes`` (default option) only single CUDA backend can be built - (`#270 `_). - -.. _project_cleanup: - -Project Cleanup -############### - -Most use-cases involve building the project without the need to cleanup the -build directory. However, if you wish to cleanup the build directory, you can -delete the ``build`` folder and create a new one. If you wish to cleanup the -build files but retain the build configuration, following commands will help -you do so. They apply to both ``Conan`` and ``CMake`` methods of building -this project. - -.. code-block:: sh - - # If you use "GNU/Unix Makefiles" for building, - make clean - - # If you use "Ninja" for building - ninja -t clean diff --git a/docs/building_the_project_with_adaptivecpp.rst b/docs/building_the_project_with_adaptivecpp.rst new file mode 100644 index 000000000..98c763b90 --- /dev/null +++ b/docs/building_the_project_with_adaptivecpp.rst @@ -0,0 +1,171 @@ +.. _building_the_project_with_adaptivecpp: + +Building the Project with AdaptiveCpp +===================================== + +.. _build_setup_with_adaptivecpp: + +Environment Setup +################# + +#. + Build and install AdaptiveCpp. For a detailed description of available + AdaptiveCpp backends, their dependencies, and installation, see the + `AdaptiveCpp installation readme + `_. + +#. + Clone this project. The root directory of the cloned repository will be + referred to as ````. + +#. + Download and install the `required dependencies + `_ + manually. + +Build Commands +############### + +In most cases, building oneMKL Interfaces is as simple as setting the compiler and +selecting the desired backends to build with. + +On Linux (other OSes are not supported with the AdaptiveCpp compiler): + +.. code-block:: bash + + # Inside + mkdir build && cd build + cmake .. -DONEMKL_SYCL_IMPLEMENTATION=hipsycl \ # Indicate that AdaptiveCpp is being used. + -DENABLE_MKLGPU_BACKEND=False \ # MKLGPU backend is not supported by AdaptiveCpp + -DENABLE__BACKEND=True \ # Enable backend(s) (optional) + -DENABLE__BACKEND=True \ # Multiple backends can be enabled at once. + -DHIPSYCL_TARGETS=omp/;hip:gfx90a,gfx906 \ # Set target architectures depending on supported devices. + -DBUILD_FUNCTIONAL_TESTS=False \ # See section *Building the tests* for more on building tests. True by default. + -DBUILD_EXAMPLES=False # Optional: True by default. + cmake --build . + cmake --install . --prefix # required to have full package structure + +Backends should be enabled by setting ``-DENABLE__BACKEND=True`` for +each desired backend. By default, the ``MKLGPU`` and ``MKLCPU`` backends are +enabled, but ``MKLGPU`` must be disabled with AdaptiveCpp. The supported +backends for the compilers are given in the table at `oneMKL supported +configurations table +`_, +and the CMake option names are given in the table below. Some backends may +require additional parameters to be set. See the relevant section below for +additional guidance. The target architectures must be specified with +``HIP_TARGETS``. See the `AdaptiveCpp documentation +`_. + +If a backend library supports multiple domains (i.e. BLAS, RNG), it may be +desirable to only enable selected domains. For this, the ``TARGET_DOMAINS`` +variable should be set. For further details, see :ref:`_build_target_domains`. + +By default, the library also additionally builds examples and tests. These can +be disabled by setting the parameters ``BUILD_FUNCTIONAL_TESTS`` and +``BUILD_EXAMPLES`` to False. Building the functional tests may require additional +external libraries. See the section :ref:`building_and_running_tests` for more +information. + +The most important supported build options are: + +.. list-table:: + :header-rows: 1 + + * - CMake Option + - Supported Values + - Default Value + * - ENABLE_MKLCPU_BACKEND + - True, False + - True + * - ENABLE_CUBLAS_BACKEND + - True, False + - False + * - ENABLE_CURAND_BACKEND + - True, False + - False + * - ENABLE_NETLIB_BACKEND + - True, False + - False + * - ENABLE_ROCBLAS_BACKEND + - True, False + - False + * - ENABLE_ROCRAND_BACKEND + - True, False + - False + * - ENABLE_MKLCPU_THREAD_TBB + - True, False + - True + * - BUILD_FUNCTIONAL_TESTS + - True, False + - True + * - BUILD_EXAMPLES + - True, False + - True + * - TARGET_DOMAINS (list) + - blas, rng + - All supported domains + +Some additional build options are given in +:ref:`build_additional_options_dpcpp`. + +Backends +######## + +.. _build_for_cuda_adaptivecpp: + +Building for CUDA +~~~~~~~~~~~~~~~~~ + +The CUDA backends can be enabled with ``ENABLE_CUBLAS_BACKEND`` and +``ENABLE_CURAND_BACKEND``. + +The target architecture must be set using the ``HIPSYCL_TARGETS`` parameter. For +example, to target a Nvidia A100 (Ampere architecture), set +``-DHIPSYCL_TARGETS=cuda:sm_80``, where the figure ``80`` corresponds to a CUDA +compute capability of 8.0. The correspondence between compute capabilities and +Nvidia GPU products is given on the `Nvidia website +`_. Multiple architectures can be +enabled using a comma separated list. See the `AdaptiveCpp documentation +`_. + +No additional parameters are required for using CUDA libraries. In most cases, +the CUDA libraries should be found automatically by CMake. + +.. _build_for_rocm_adaptivecpp: + +Building for ROCm +~~~~~~~~~~~~~~~~~ + +The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND`` and +``ENABLE_ROCRAND_BACKEND``. + +The target architecture must be set using the ``HIPSYCL_TARGETS`` parameter. See +the `AdaptiveCpp documentation +`_. +For example, to target the MI200 series, set ``-DHIPSYCL_TARGETS=hip:gfx90a``. +Multiple architectures can be enabled using a comma separated list. For example, +``-DHIPSYCL_TARGETS=hip:gfx906,gfx90a``, and multiple APIs with a semicolon +(``-DHIPSYCL_TARGETS=omp\;hip:gfx906,gfx90a``). + +For common AMD GPU architectures, see the :ref:`build_for_ROCM_dpcpp` in the +DPC++ build guide. + +.. _project_cleanup: + +Project Cleanup +############### + +Most use-cases involve building the project without the need to clean up the +build directory. However, if you wish to clean up the build directory, you can +delete the ``build`` folder and create a new one. If you wish to clean up the +build files but retain the build configuration, following commands will help you +do so. + +.. code-block:: sh + + # If you use "GNU/Unix Makefiles" for building, + make clean + + # If you use "Ninja" for building + ninja -t clean diff --git a/docs/building_the_project_with_dpcpp.rst b/docs/building_the_project_with_dpcpp.rst new file mode 100644 index 000000000..365028237 --- /dev/null +++ b/docs/building_the_project_with_dpcpp.rst @@ -0,0 +1,475 @@ +.. _building_the_project_with_dpcpp: + +Building the Project with DPC++ +=============================== + +This page describes building the oneMKL Interfaces with either the Intel(R) +oneAPI DPC++ Compiler or open-source oneAPI DPC++ Compiler. For guidance on +building the project with AdaptiveCpp, see +:ref:`building_the_project_with_adaptivecpp`. + +.. _build_setup_with_dpcpp: + +Environment Setup +################## + +#. + Install the required DPC++ compiler (Intel(R) DPC++ or Open DPC++ - see + :ref:`Selecting a Compiler`). + +#. + Clone this project. The root directory of the cloned repository will be + referred to as ````. + +#. + Build and install all `required dependencies + `_. + +.. _build_introduction_with_dpcpp: + +Build Commands +############### + +The build commands for various compilers and backends differ mostly in setting +the values of CMake options for compiler and backend. In this section, we +describe the common build commands. We will discuss backend-specific details in +the `Backends`_ section and provide examples in `CMake invocation examples`_. + +On Linux, the common form of the build command looks as follows (see `Building +for Windows`_ for building on Windows): + +.. code-block:: bash + + # Inside + mkdir build && cd build + cmake .. -DCMAKE_CXX_COMPILER=$CXX_COMPILER \ # Should be icpx or clang++ + -DCMAKE_C_COMPILER=$C_COMPILER \ # Should be icx or clang + -DENABLE_MKLGPU_BACKEND=False \ # Optional: The MKLCPU backend is True by default. + -DENABLE_MKLGPU_BACKEND=False \ # Optional: The MKLGPU backend is True by default. + -DENABLE__BACKEND=True \ # Enable any other backend(s) (optional) + -DENABLE__BACKEND=True \ # Multiple backends can be enabled at once. + -DBUILD_FUNCTIONAL_TESTS=False \ # See page *Building and Running Tests* for more on building tests. True by default. + -DBUILD_EXAMPLES=False # Optional: True by default. + cmake --build . + cmake --install . --prefix # required to have full package structure + +In the above, the ``$CXX_COMPILER`` and ``$C_COMPILER`` should be set to +``icpx`` and ``icx`` respectively when using the Intel(R) oneAPI DPC++ Compiler, +or ``clang++`` and ``clang`` respectively when using the Open DPC++ Compiler. + +Backends should be enabled by setting ``-DENABLE__BACKEND=True`` for +each desired backend. By default, only the ``MKLGPU`` and ``MKLCPU`` backends +are enabled. Multiple backends for multiple device vendors can be enabled at +once (albeit with limitations when using portBLAS and portFFT). The supported +backends for the compilers are given in the table at `oneMKL supported +configurations table +`_, +and the CMake option names are given in the table below. Some backends may +require additional parameters to be set. See the relevant section below for +additional guidance. + +If a backend library supports multiple domains (i.e., BLAS, LAPACK, DFT, RNG, +sparse BLAS), it may be desirable to only enable selected domains. For this, the +``TARGET_DOMAINS`` variable should be set. See the section `TARGET_DOMAINS`_. + +By default, the library also additionally builds examples and tests. These can +be disabled by setting the parameters ``BUILD_FUNCTIONAL_TESTS`` and +``BUILD_EXAMPLES`` to ``False``. Building the functional tests requires +additional external libraries for the BLAS and LAPACK domains. See the section +:ref:`building_and_running_tests` for more information. + +The most important supported build options are: + +.. list-table:: + :header-rows: 1 + + * - CMake Option + - Supported Values + - Default Value + * - ENABLE_MKLCPU_BACKEND + - True, False + - True + * - ENABLE_MKLGPU_BACKEND + - True, False + - True + * - ENABLE_CUBLAS_BACKEND + - True, False + - False + * - ENABLE_CUSOLVER_BACKEND + - True, False + - False + * - ENABLE_CUFFT_BACKEND + - True, False + - False + * - ENABLE_CURAND_BACKEND + - True, False + - False + * - ENABLE_NETLIB_BACKEND + - True, False + - False + * - ENABLE_ROCBLAS_BACKEND + - True, False + - False + * - ENABLE_ROCFFT_BACKEND + - True, False + - False + * - ENABLE_ROCSOLVER_BACKEND + - True, False + - False + * - ENABLE_ROCRAND_BACKEND + - True, False + - False + * - ENABLE_MKLCPU_THREAD_TBB + - True, False + - True + * - ENABLE_PORTBLAS_BACKEND + - True, False + - False + * - ENABLE_PORTFFT_BACKEND + - True, False + - False + * - BUILD_FUNCTIONAL_TESTS + - True, False + - True + * - BUILD_EXAMPLES + - True, False + - True + * - TARGET_DOMAINS (list) + - blas, lapack, rng, dft, sparse_blas + - All domains + +Some additional build options are given in the section `Additional build options`_. + +.. _build_target_domains: + +TARGET_DOMAINS +^^^^^^^^^^^^^^ + +oneMKL supports multiple domains: BLAS, DFT, LAPACK, RNG and sparse BLAS. The +domains built by oneMKL can be selected using the ``TARGET_DOMAINS`` parameter. +In most cases, ``TARGET_DOMAINS`` is set automatically according to the domains +supported by the backend libraries enabled. However, while most backend +libraries support only one of these domains, but some may support multiple. For +example, the ``MKLCPU`` backend supports every domain. To enable support for +only the BLAS domain in the oneMKL Interfaces whilst compiling with ``MKLCPU``, +``TARGET_DOMAINS`` could be set to ``blas``. To enable BLAS and DFT, +``-DTARGET_DOMAINS="blas dft"`` would be used. + + +Backends +######### + +.. _build_for_intel_onemkl_dpcpp: + +Building for Intel(R) oneMKL +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The Intel(R) oneMKL backend supports multiple domains on both x86 CPUs and Intel +GPUs. The MKLCPU backend using Intel(R) oneMKL for x86 CPU is enabled by +default, and controlled with the parameter ``ENABLE_MKLCPU_BACKEND``. The MKLGPU +backend using Intel(R) oneMKL for Intel GPU is enabled by default, and +controlled with the parameter ``ENABLE_MKLGPU_BACKEND``. + +When using the Intel(R) oneAPI DPC++ Compiler, it is likely that Intel(R) oneMKL +will be found automatically. If it is not, the parameter ``MKL_ROOT`` can be set +to point to the installation prefix of Intel(R) oneMKL. Alternatively, the +``MKLROOT`` environment variable can be set, either manually or by using an +environment script provided by the package. + + +.. _build_for_CUDA_dpcpp: + +Building for CUDA +^^^^^^^^^^^^^^^^^ + +The CUDA backends can be enabled with ``ENABLE_CUBLAS_BACKEND``, +``ENABLE_CUFFT_BACKEND``, ``ENABLE_CURAND_BACKEND``, and +``ENABLE_CUSOLVER_BACKEND``. + +No additional parameters are required for using CUDA libraries. In most cases, +the CUDA libraries should be found automatically by CMake. + +.. _build_for_ROCM_dpcpp: + +Building for ROCm +^^^^^^^^^^^^^^^^^ + +The ROCm backends can be enabled with ``ENABLE_ROCBLAS_BACKEND``, +``ENABLE_ROCFFT_BACKEND``, ``ENABLE_ROCSOLVER_BACKEND`` and +``ENABLE_ROCRAND_BACKEND``. + +For *RocBLAS*, *RocSOLVER* and *RocRAND*, the target device architecture must be +set. This can be set with using the ``HIP_TARGETS`` parameter. For example, to +enable a build for MI200 series GPUs, ``-DHIP_TARGETS=gfx90a`` should be set. +Currently, DPC++ can only build for a single HIP target at a time. This may +change in future versions. + +A few often-used architectures are listed below: + +.. list-table:: + :header-rows: 1 + + * - Architecture + - AMD GPU name + * - gfx90a + - AMD Instinct(TM) MI210/250/250X Accelerator + * - gfx908 + - AMD Instinct(TM) MI 100 Accelerator + * - gfx906 + - | AMD Radeon Instinct(TM) MI50/60 Accelerator + | AMD Radeon(TM) (Pro) VII Graphics Card + * - gfx900 + - | Radeon Instinct(TM) MI 25 Accelerator + | Radeon(TM) RX Vega 64/56 Graphics + +For a host with ROCm installed, the device architecture can be retrieved via the +``rocminfo`` tool. The architecture will be displayed in the ``Name:`` row. + +.. _build_for_portlibs_dpcpp: + +Pure SYCL backends: portBLAS and portFFT +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +`portBLAS `_ and `portFFT +`_ are experimental pure-SYCL +backends that work on all SYCL targets supported by the DPC++ compiler. Since +they support multiple targets, they cannot be enabled with other backends in the +same domain, or the MKLCPU or MKLGPU backends. Both libraries are experimental +and currently only support a subset of operations and features. + +For best performance, both libraries must be tuned. See the individual sections +for more details. + +Both portBLAS and portFFT are used as header-only libraries, and will be +downloaded automatically if not found. + +.. _build_for_portblas_dpcpp: + +Building for portBLAS +--------------------- + +`portBLAS `_ is +enabled by setting ``-DENABLE_PORTBLAS_BACKEND=True``. + +By default, the portBLAS backend is not tuned for any specific device. +This tuning is required to achieve best performance. +portBLAS can be tuned for a specific hardware target by adding compiler +definitions in 2 ways: + +#. + Manually specify a tuning target with ``-DPORTBLAS_TUNING_TARGET=``. + The list of portBLAS targets can be found + `here `_. + This will automatically set ``-fsycl-targets`` if needed. +#. + If one target is set via ``-fsycl-targets`` the configuration step will + try to automatically detect the portBLAS tuning target. One can manually + specify ``-fsycl-targets`` via ``CMAKE_CXX_FLAGS``. See + `DPC++ User Manual `_ + for more information on ``-fsycl-targets``. + +portBLAS relies heavily on JIT compilation. This may cause time-outs on some +systems. To avoid this issue, use ahead-of-time compilation through tuning +targets or ``sycl-targets``. + +.. _build_for_portfft_dpcpp: + +Building for portFFT +--------------------- + +`portFFT `_ is enabled by setting +``-DENABLE_PORTFFT_BACKEND=True``. + +By default, the portFFT backend is not tuned for any specific device. The tuning +flags are detailed in the `portFFT +`_ repository, and can set at +configuration time. Note that some tuning configurations may be incompatible +with some targets. + +The portFFT library is compiled using the same ``-fsycl-targets`` as specified +by the ``CMAKE_CXX_FLAGS``. If none are found, it will compile for +``-fsycl-targets=spir64``, and -if the compiler supports it- +``nvptx64-nvidia-cuda``. To enable HIP targets, ``HIP_TARGETS`` must be +specified. See `DPC++ User Manual +`_ for more information on +``-fsycl-targets``. + +.. _build_additional_options_dpcpp: + +Additional Build Options +########################## + +When building oneMKL the SYCL implementation can be specified by setting the +``ONEMKL_SYCL_IMPLEMENTATION`` option. Possible values are: + +* ``dpc++`` (default) for the `Intel(R) oneAPI DPC++ Compiler + `_ and for the `oneAPI + DPC++ Compiler `_ compilers. +* ``hipsycl`` for the `AdaptiveCpp `_ + SYCL implementation. +Please see :ref:`building_the_project_with_adaptivecpp` if using this option. + +The following table provides details of CMake options and their default values: + +.. list-table:: + :header-rows: 1 + + * - CMake Option + - Supported Values + - Default Value + * - BUILD_SHARED_LIBS + - True, False + - True + * - BUILD_DOC + - True, False + - False + + +.. note:: + When building with ``clang++`` for AMD backends, you must additionally set + ``ONEAPI_DEVICE_SELECTOR`` to ``hip:gpu`` and provide ``-DHIP_TARGETS`` + according to the targeted hardware. This backend has only been tested for the + ``gfx90a`` architecture (MI210) at the time of writing. + +.. note:: + When building with ``BUILD_FUNCTIONAL_TESTS=True`` (default option) only single CUDA backend can be built + (`#270 `_). + + +.. _build_invocation_examples_dpcpp: + +CMake invocation examples +########################## + +Build oneMKL with support for Nvidia GPUs with tests +disabled using the Ninja build system: + +.. code-block:: bash + + cmake $ONEMKL_DIR \ + -GNinja \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DENABLE_MKLGPU_BACKEND=False \ + -DENABLE_MKLCPU_BACKEND=False \ + -DENABLE_CUFFT_BACKEND=True \ + -DENABLE_CUBLAS_BACKEND=True \ + -DENABLE_CUSOLVER_BACKEND=True \ + -DENABLE_CURAND_BACKEND=True \ + -DBUILD_FUNCTIONAL_TESTS=False + +``$ONEMKL_DIR`` points at the oneMKL source directly. The x86 CPU (``MKLCPU``) +and Intel GPU (``MKLGPU``) backends are enabled by default, but are disabled +here. The backends for Nvidia GPUs must all be explicilty enabled. The tests are +disabled, but the examples will still be built. + +Building oneMKL with support for AMD GPUs with tests +disabled: + +.. code-block:: bash + + cmake $ONEMKL_DIR \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER=clang \ + -DENABLE_MKLCPU_BACKEND=False \ + -DENABLE_MKLGPU_BACKEND=False \ + -DENABLE_ROCFFT_BACKEND=True \ + -DENABLE_ROCBLAS_BACKEND=True \ + -DENABLE_ROCSOLVER_BACKEND=True \ + -DHIP_TARGETS=gfx90a \ + -DBUILD_FUNCTIONAL_TESTS=False + +``$ONEMKL_DIR`` points at the oneMKL source directly. The x86 CPU (``MKLCPU``) +and Intel GPU (``MKLGPU``) backends are enabled by default, but are disabled +here. The backends for AMD GPUs must all be explicilty enabled. The tests are +disabled, but the examples will still be built. + + +Build oneMKL for the DFT domain only with support for x86 CPU, Intel GPU, AMD +GPU and Nvidia GPU with testing enabled: + +.. code-block:: bash + + cmake $ONEMKL_DIR \ + -DCMAKE_CXX_COMPILER=icpx \ + -DCMAKE_C_COMPILER=icx \ + -DENABLE_ROCFFT_BACKEND=True \ + -DENABLE_CUFFT_BACKEND=True \ + -DTARGET_DOMAINS=dft \ + -DBUILD_EXAMPLES=False + +Note that this is not a supported configuration, and requires Codeplay's oneAPI +for `AMD `_ and +`Nvidia `_ GPU +plugins. The MKLCPU and MKLGPU backends are enabled by +default, with backends for Nvidia GPU and AMD GPU explicitly enabled. +``-DTARGET_DOMAINS=dft`` causes only DFT backends to be built. If this was not +set, the backend libraries to enable the use of BLAS, LAPACK and RNG with MKLGPU +and MKLCPU would also be enabled. The build of examples is disabled. Since +functional testing was not disabled, tests would be built. + +.. _project_cleanup: + +Project Cleanup +############### + +Most use-cases involve building the project without the need to clean up the +build directory. However, if you wish to clean up the build directory, you can +delete the ``build`` folder and create a new one. If you wish to clean up the +build files but retain the build configuration, following commands will help you +do so. + +.. code-block:: sh + + # If you use "GNU/Unix Makefiles" for building, + make clean + + # If you use "Ninja" for building + ninja -t clean + + +.. _build_for_windows_dpcpp: + +Building for Windows +#################### + +The Windows build is similar to the Linux build, albeit that `fewer backends are +supported `_. +Additionally, the Ninja build system must be used. For example: + +.. code-block:: bash + + # Inside + md build && cd build + cmake .. -G Ninja [-DCMAKE_CXX_COMPILER=\bin\icx] # required only if icx is not found in environment variable PATH + [-DCMAKE_C_COMPILER=\bin\icx] # required only if icx is not found in environment variable PATH + [-DMKL_ROOT=] # required only if environment variable MKLROOT is not set + [-DREF_BLAS_ROOT=] # required only for testing + [-DREF_LAPACK_ROOT=] # required only for testing + ninja + ctest + cmake --install . --prefix # required to have full package structure + +.. _build_common_problems_dpcpp: + +Build FAQ +######### + +clangrt builtins lib not found + Encountered when trying to build oneMKL with some ROCm libraries. There are + several possible solutions: * If building Open DPC++ from source, add + ``compiler-rt`` to the external projects compile option: + ``--llvm-external-projects compiler-rt``. * The *clangrt* from ROCm can be + used, depending on ROCm version: ``export + LIBRARY_PATH=/path/to/rocm-$rocm-version$/llvm/lib/clang/$clang-version$/lib/linux/:$LIBRARY_PATH`` + +Could NOT find CBLAS (missing: CBLAS file) + Encountered when tests are enabled along with the BLAS domain. The tests + require a reference BLAS implementation, but cannot find one. Either install + or build a BLAS library and set ``-DREF_BLAS_ROOT``` as described in + :ref:`building_and_running_tests`. Alternatively, the tests can be disabled by + setting ``-DBUILD_FUNCTIONAL_TESTS=False``. + +error: invalid target ID ''; format is a processor name followed by an optional colon-delimited list of features followed by an enable/disable sign (e.g.,'gfx908:sramecc+:xnack-') + The HIP_TARGET has not been set. Please see `Building for ROCm`_. + diff --git a/docs/conf.py b/docs/conf.py deleted file mode 100644 index 532ab6ffa..000000000 --- a/docs/conf.py +++ /dev/null @@ -1,201 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Configuration file for the Sphinx documentation builder. -# -# This file does only contain a selection of the most common options. For a -# full list see the documentation: -# http://www.sphinx-doc.org/en/master/config - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, as shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - - -# -- Project information ----------------------------------------------------- - -project = 'oneAPI Math Kernel Library Interfaces' -copyright = '2020-2022, Intel Corporation' -author = 'Intel Corporation' - -# The short X.Y version -version = '' -# The full version, including alpha/beta/rc tags -release = '0.1' - - -# -- General configuration --------------------------------------------------- - -# If your documentation needs a minimal Sphinx version, state it here. -# -# needs_sphinx = '1.0' - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - 'sphinx.ext.autodoc', - 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', - 'sphinx.ext.todo', -] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ['_templates'] - -# The suffix(es) of source filenames. -# You can specify multiple suffix as a list of string: -# -# source_suffix = ['.rst', '.md'] -source_suffix = '.rst' - -# The master toctree document. -master_doc = 'index' - -# The language for content autogenerated by Sphinx. Refer to documentation -# for a list of supported languages. -# -# This is also used if you do content translation via gettext catalogs. -# Usually you set "language" from the command line for these cases. -language = None - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] - -# The name of the Pygments (syntax highlighting) style to use. -pygments_style = None - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. -# -html_theme = 'sphinx_book_theme' -html_logo = '_static/oneAPI-rgb-rev-100.png' -html_favicon = '_static/favicons.png' - -# Theme options are theme-specific and customize the look and feel of a theme -# further. For a list of options available for each theme, see the -# documentation. -# - -# Theme options -html_theme_options = { -'repository_url': 'https://github.com/oneapi-src/oneMKL', -'path_to_docs': 'docs', -'use_issues_button': True, -'use_edit_page_button': True, -'repository_branch': 'develop', -'extra_footer': '

Cookies

' -} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] - -# Custom sidebar templates, must be a dictionary that maps document names -# to template names. -# -# The default sidebars (for documents that don't match any pattern) are -# defined by theme itself. Builtin themes are using these templates by -# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', -# 'searchbox.html']``. -# -# html_sidebars = {} - - -# -- Options for HTMLHelp output --------------------------------------------- - -# Output file base name for HTML help builder. -htmlhelp_basename = 'reSTTemplatedoc' - - -# -- Options for LaTeX output ------------------------------------------------ - -latex_elements = { - # The paper size ('letterpaper' or 'a4paper'). - # - # 'papersize': 'letterpaper', - - # The font size ('10pt', '11pt' or '12pt'). - # - # 'pointsize': '10pt', - - # Additional stuff for the LaTeX preamble. - # - # 'preamble': '', - - # Latex figure (float) alignment - # - # 'figure_align': 'htbp', -} - -# Grouping the document tree into LaTeX files. List of tuples -# (source start file, target name, title, -# author, documentclass [howto, manual, or own class]). -latex_documents = [ - (master_doc, 'reSTTemplate.tex', 'reST Template Documentation', - 'Ben Fitch', 'manual'), -] - - -# -- Options for manual page output ------------------------------------------ - -# One entry per manual page. List of tuples -# (source start file, name, description, authors, manual section). -man_pages = [ - (master_doc, 'resttemplate', 'reST Template Documentation', - [author], 1) -] - - -# -- Options for Texinfo output ---------------------------------------------- - -# Grouping the document tree into Texinfo files. List of tuples -# (source start file, target name, title, author, -# dir menu entry, description, category) -texinfo_documents = [ - (master_doc, 'reSTTemplate', 'reST Template Documentation', - author, 'reSTTemplate', 'One line description of project.', - 'Miscellaneous'), -] - - -# -- Options for Epub output ------------------------------------------------- - -# Bibliographic Dublin Core info. -epub_title = project - -# The unique identifier of the text. This can be a ISBN number -# or the project homepage. -# -# epub_identifier = '' - -# A unique identification for the text. -# -# epub_uid = '' - -# A list of files that should not be packed into the epub file. -epub_exclude_files = ['search.html'] - - -# -- Extension configuration ------------------------------------------------- - -# -- Options for intersphinx extension --------------------------------------- - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} - -# -- Options for todo extension ---------------------------------------------- - -# If true, `todo` and `todoList` produce output, else they produce nothing. -todo_include_todos = True diff --git a/docs/conf.py.in b/docs/conf.py.in index 793c9b87c..d874dbab7 100644 --- a/docs/conf.py.in +++ b/docs/conf.py.in @@ -10,7 +10,7 @@ # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. +# documentation root, use os.path.abspath to make it absolute, as shown here. # # import os # import sys @@ -20,7 +20,7 @@ # -- Project information ----------------------------------------------------- project = 'oneAPI Math Kernel Library Interfaces' -copyright = '2020, Intel Corporation' +copyright = '2020-2022, Intel Corporation' author = 'Intel Corporation' # The short X.Y version @@ -41,7 +41,6 @@ release = '0.1' extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.doctest', - 'sphinx.ext.intersphinx', 'sphinx.ext.todo', ] @@ -62,7 +61,7 @@ master_doc = 'index' # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. -language = None +language = 'en' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. @@ -72,24 +71,38 @@ exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None +static_dir = '@CMAKE_CURRENT_SOURCE_DIR@/_static' + # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # -html_theme = 'alabaster' +html_theme = 'sphinx_book_theme' +html_logo = f'{static_dir}/oneAPI-rgb-rev-100.png' +html_favicon = f'{static_dir}/favicons.png' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # -# html_theme_options = {} + +# Theme options +html_theme_options = { +'repository_url': 'https://github.com/oneapi-src/oneMKL', +'path_to_docs': 'docs', +'use_issues_button': True, +'use_edit_page_button': True, +'repository_branch': 'develop', +'extra_footer': '

Cookies

', +'navigation_with_keys': False, +} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ['_static'] +html_static_path = [static_dir] # Custom sidebar templates, must be a dictionary that maps document names # to template names. @@ -179,11 +192,6 @@ epub_exclude_files = ['search.html'] # -- Extension configuration ------------------------------------------------- -# -- Options for intersphinx extension --------------------------------------- - -# Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} - # -- Options for todo extension ---------------------------------------------- # If true, `todo` and `todoList` produce output, else they produce nothing. diff --git a/docs/index.rst b/docs/index.rst index e1a051524..51e4216ee 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,5 +1,5 @@ .. - Copyright 20202022 Intel Corporation + Copyright 2020-2024 Intel Corporation .. _onemkl: @@ -21,7 +21,10 @@ Contents :maxdepth: 2 selecting_a_compiler.rst - building_the_project.rst + building_the_project_with_dpcpp.rst + building_the_project_with_adaptivecpp.rst + building_and_running_tests.rst + using_onemkl_with_cmake.rst .. toctree:: :caption: Developer Reference diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 000000000..8365d7241 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,28 @@ +accessible-pygments==0.0.5 +alabaster==0.7.16 +Babel==2.15.0 +beautifulsoup4==4.12.3 +certifi==2024.7.4 +charset-normalizer==3.3.2 +docutils==0.21.2 +idna==3.7 +imagesize==1.4.1 +Jinja2==3.1.4 +MarkupSafe==2.1.5 +packaging==24.0 +pydata-sphinx-theme==0.15.2 +Pygments==2.18.0 +requests==2.32.1 +snowballstemmer==2.2.0 +soupsieve==2.5 +Sphinx==7.3.7 +sphinx-book-theme==1.1.2 +sphinxcontrib-applehelp==1.0.8 +sphinxcontrib-devhelp==1.0.6 +sphinxcontrib-htmlhelp==2.0.5 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==1.0.7 +sphinxcontrib-serializinghtml==1.1.10 +tomli==2.0.1 +typing_extensions==4.11.0 +urllib3==2.2.2 diff --git a/docs/selecting_a_compiler.rst b/docs/selecting_a_compiler.rst index 78ea3fb0d..8c09e60b4 100644 --- a/docs/selecting_a_compiler.rst +++ b/docs/selecting_a_compiler.rst @@ -3,17 +3,17 @@ Selecting a Compiler ==================== -You must choose a compiler according to the required backend of your +You must choose a compiler according to the required backend and the operating system of your application. * If your application requires Intel GPU, use `Intel(R) oneAPI DPC++ Compiler `_ ``icpx`` on Linux or ``icx`` on Windows. -* If your application requires NVIDIA GPU, use the latest release of - ``clang++`` from `Intel project for LLVM* technology `_ or use ``hipSYCL`` from the `hipSYCL repository `_ (except for LAPACK domain). -* If your application requires AMD GPU, use ``hipSYCL`` or use the latest release of ``clang++`` from `Intel project for LLVM* technology `_. -* If no Intel GPU, NVIDIA GPU, or AMD GPU is required, on Linux you can use either +* If your Linux application requires NVIDIA GPU, build ``clang++`` from the latest source of + `oneAPI DPC++ Compiler `_ with `support for NVIDIA CUDA `_ or use ``hipSYCL`` from the `hipSYCL repository `_ (except for LAPACK domain). +* If your Linux application requires AMD GPU, build ``clang++`` from the latest source of `oneAPI DPC++ Compiler `_ with `support for HIP AMD `_ or use ``hipSYCL``. +* If no Intel GPU, NVIDIA GPU, or AMD GPU is required, on Linux you can use `Intel(R) oneAPI DPC++ Compiler `_ - ``icpx``, ``clang++``, or ``hipSYCL`` and on Windows you can use either + ``icpx``, `oneAPI DPC++ Compiler `_ ``clang++``, or ``hipSYCL``, + and on Windows you can use either `Intel(R) oneAPI DPC++ Compiler `_ - ``icx``, or ``clang-cl`` from - `Intel project for LLVM* technology `_. + ``icx`` or `oneAPI DPC++ Compiler `_ ``clang-cl``. diff --git a/docs/using_onemkl_with_cmake.rst b/docs/using_onemkl_with_cmake.rst new file mode 100644 index 000000000..5fb497362 --- /dev/null +++ b/docs/using_onemkl_with_cmake.rst @@ -0,0 +1,102 @@ +.. _using_onemkl_interface_library_with_cmake: + +Using the oneMKL Interfaces in your project with CMake +============================================================= + +The CMake build tool can help you use oneMKL Interfaces in your own project. +Instead of manually linking and including directories, you can use the CMake targets +exported by the oneMKL Interfaces project. You can use oneMKL in one of two +forms, with the target names depending on the approach taken: + +* you can use a previously installed copy, either from a binary distribution or + built from source. This can be imported using CMake's ``find_package`` + command. See the section `using_from_installed_binary`_. +* or you can have CMake automatically download and build oneMKL as part of the + build process using CMake's FetchContent_ functionality. + See the section `using_with_fetchcontent`_. + + +.. _using_from_installed_binary: + +Using an installed oneMKL Interfaces +#################################### + +If the oneMKL Interfaces have been previously installed, either by building from +source or as a distributed binary, they can be consumed using CMake using +``find_package(oneMKL REQUIRED)``. The compiler used for the target library or +application should match that used to build oneMKL Interfaces. + +For example: + +.. code-block:: cmake + + find_package(oneMKL REQUIRED) + target_link_libraries(myTarget PRIVATE MKL::onemkl) + +Different targets can be used depending on the requirements of oneMKL. +To link against the entire library, the ``MKL::onemkl`` target should be used. +For specific domains, ``MKL::onemkl_`` should be used. +And for specific backends, ``MKL::onemkl__`` should be used. + +When using a binary, it may be useful to know the backends that were enabled +during the build. To check for the existence of backends, CMake's ``if(TARGET +)`` construct can be used. For example, with the ``cufft`` backend: + +.. code-block:: cmake + + if(TARGET MKL::onemkl_dft_cufft) + target_link_libraries(myTarget PRIVATE MKL::onemkl_dft_cufft) + else() + message(FATAL_ERROR "oneMKL Interfaces was not built with CuFFT backend") + endif() + + +If oneMKL Interfaces has been installed to a non-standard location, the +operating system may not find the backend libraries when they're lazily loaded +at runtime. To make sure they're found you may need to set +``LD_LIBRARY_PATH=/lib:$LD_LIBRARY_PATH`` on Linux. + +.. _using_with_fetchcontent: + +Using CMake's FetchContent +########################## + + +The FetchContent_ functionality of CMake can be used to download, build and +install oneMKL Interfaces as part of the build. + +For example: + +.. code-block:: cmake + + include(FetchContent) + set(BUILD_FUNCTIONAL_TESTS False) + set(BUILD_EXAMPLES False) + set(ENABLE__BACKEND True) + FetchContent_Declare( + onemkl_interface_library + GIT_REPOSITORY https://github.com/oneapi-src/oneMKL.git + GIT_TAG develop + ) + FetchContent_MakeAvailable(onemkl_interface_library) + + target_link_libraries(myTarget PRIVATE onemkl) + +The build parameters should be appropriately set before +``FetchContent_Declare``. See :ref:`building_the_project_with_dpcpp` or +:ref:`building_the_project_with_adaptivecpp`. + +To link against the main library with run-time dispatching, use the target +``onemkl``. To link against particular domains, use the target +``onemkl_``. For example, ``onemkl_blas`` or ``onemkl_dft``. To link +against particular backends (as required for static dispatch of oneAPI calls to +a particular backend), use the target ``onemkl__``. For +example, ``onemkl_dft_cufft``. + +When using the run-time dispatch mechanism, it is likely that the operating +system will not find the backend libraries when they're loaded at runtime. To +make sure they're found you may need to set +``LD_LIBRARY_PATH=/lib:$LD_LIBRARY_PATH`` on Linux. + + +.. _FetchContent: https://cmake.org/cmake/help/latest/module/FetchContent.html diff --git a/examples/README.md b/examples/README.md index 58fec1f6e..0dad8772d 100644 --- a/examples/README.md +++ b/examples/README.md @@ -3,26 +3,26 @@ oneAPI Math Kernel Library (oneMKL) Interfaces offers examples with the followin - blas: level3/gemm_usm - rng: uniform_usm - lapack: getrs_usm -- dft: complex_fwd_buffer, real_fwd_usm +- dft: complex_fwd_usm, real_fwd_usm +- sparse_blas: sparse_gemv_usm Each routine has one run-time dispatching example and one compile-time dispatching example (which uses both mklcpu and cuda backends), located in `example/<$domain>/run_time_dispatching` and `example/<$domain>/compile_time_dispatching` subfolders, respectively. To build examples, use cmake build option `-DBUILD_EXAMPLES=true`. Compile_time_dispatching will be built if `-DBUILD_EXAMPLES=true` and cuda backend is enabled, because the compile-time dispatching example runs on both mklcpu and cuda backends. Run_time_dispatching will be built if `-DBUILD_EXAMPLES=true` and `-DBUILD_SHARED_LIBS=true`. -All DFT examples require the mklgpu backend to be enabled. The example executable naming convention follows `example_<$domain>_<$routine>_<$backend>` for compile-time dispatching examples or `example_<$domain>_<$routine>` for run-time dispatching examples. E.g. `example_blas_gemm_usm_mklcpu_cublas ` `example_blas_gemm_usm` -## Example outputs (blas, rng, lapack, dft) +## Example outputs (blas, rng, lapack, dft, sparse_blas) ## blas Run-time dispatching examples with mklcpu backend ``` -$ export SYCL_DEVICE_FILTER=cpu +$ export ONEAPI_DEVICE_SELECTOR="opencl:cpu" $ ./bin/example_blas_gemm_usm ######################################################################## @@ -39,8 +39,8 @@ $ ./bin/example_blas_gemm_usm # Using single precision (float) data type # # Device will be selected during runtime. -# The environment variable SYCL_DEVICE_FILTER can be used to specify -# SYCL device +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices # ######################################################################## @@ -75,7 +75,7 @@ BLAS GEMM USM example ran OK. ``` Run-time dispatching examples with mklgpu backend ``` -$ export SYCL_DEVICE_FILTER=gpu +$ export ONEAPI_DEVICE_SELECTOR="level_zero:gpu" $ ./bin/example_blas_gemm_usm ######################################################################## @@ -92,8 +92,8 @@ $ ./bin/example_blas_gemm_usm # Using single precision (float) data type # # Device will be selected during runtime. -# The environment variable SYCL_DEVICE_FILTER can be used to specify -# SYCL device +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices # ######################################################################## @@ -187,7 +187,7 @@ BLAS GEMM USM example ran OK on MKLCPU and CUBLAS ## lapack Run-time dispatching example with mklgpu backend: ``` -$ export SYCL_DEVICE_FILTER=gpu +$ export ONEAPI_DEVICE_SELECTOR="level_zero:gpu" $ ./bin/example_lapack_getrs_usm ######################################################################## @@ -205,8 +205,8 @@ $ ./bin/example_lapack_getrs_usm # Using single precision (float) data type # # Device will be selected during runtime. -# The environment variable SYCL_DEVICE_FILTER can be used to specify -# SYCL device +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices # ######################################################################## @@ -288,7 +288,7 @@ LAPACK GETRS USM example ran OK on MKLCPU and CUSOLVER ## rng Run-time dispatching example with mklgpu backend: ``` -$ export SYCL_DEVICE_FILTER=gpu +$ export ONEAPI_DEVICE_SELECTOR="level_zero:gpu" $ ./bin/example_rng_uniform_usm ######################################################################## @@ -301,8 +301,8 @@ $ ./bin/example_rng_uniform_usm # Using single precision (float) data type # # Device will be selected during runtime. -# The environment variable SYCL_DEVICE_FILTER can be used to specify -# SYCL device +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices # ######################################################################## @@ -353,10 +353,10 @@ Random number generator example with uniform distribution ran OK on MKLCPU and C ## dft -Compile-time dispatching example with mklgpu backend +Compile-time dispatching example with MKLGPU backend ```none -$ SYCL_DEVICE_FILTER=gpu ./bin/example_dft_complex_fwd_buffer_mklgpu +$ ONEAPI_DEVICE_SELECTOR="level_zero:gpu" ./bin/example_dft_complex_fwd_buffer_mklgpu ######################################################################## # Complex out-of-place forward transform for Buffer API's example: @@ -369,8 +369,8 @@ $ SYCL_DEVICE_FILTER=gpu ./bin/example_dft_complex_fwd_buffer_mklgpu # # For Intel GPU with Intel MKLGPU backend. # -# The environment variable SYCL_DEVICE_FILTER can be used to specify -# SYCL device +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices ######################################################################## Running DFT Complex forward out-of-place buffer example @@ -380,29 +380,218 @@ Running with single precision real data type on: DFT Complex USM example ran OK on MKLGPU ``` -Runtime dispatching example with both mklgpu backend +Runtime dispatching example with MKLGPU, cuFFT, rocFFT and portFFT backends: ```none -SYCL_DEVICE_FILTER=gpu ./bin/example_dft_complex_fwd_buffer_mklgpu +$ ONEAPI_DEVICE_SELECTOR="level_zero:gpu" ./bin/example_dft_real_fwd_usm ######################################################################## -# Complex out-of-place forward transform for Buffer API's example: +# DFT complex in-place forward transform with USM API example: # # Using APIs: -# Compile-time dispatch API -# Buffer forward complex out-of-place +# USM forward complex in-place +# Run-time dispatch # # Using single precision (float) data type # -# For Intel GPU with Intel MKLGPU backend. +# Device will be selected during runtime. +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices # -# The environment variable SYCL_DEVICE_FILTER can be used to specify -# SYCL device ######################################################################## -Running DFT Complex forward out-of-place buffer example -Using compile-time dispatch API with MKLGPU. -Running with single precision real data type on: - GPU device :Intel(R) UHD Graphics 750 [0x4c8a] -DFT Complex USM example ran OK on MKLGPU +Running DFT complex forward example on GPU device +Device name is: Intel(R) UHD Graphics 750 [0x4c8a] +Running with single precision real data type: +DFT example run_time dispatch +DFT example ran OK +``` + +```none +$ ONEAPI_DEVICE_SELECTOR="level_zero:gpu" ./bin/example_dft_real_fwd_usm + +######################################################################## +# DFT complex in-place forward transform with USM API example: +# +# Using APIs: +# USM forward complex in-place +# Run-time dispatch +# +# Using single precision (float) data type +# +# Device will be selected during runtime. +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices +# +######################################################################## + +Running DFT complex forward example on GPU device +Device name is: NVIDIA A100-PCIE-40GB +Running with single precision real data type: +DFT example run_time dispatch +DFT example ran OK +``` + +```none +$ ./bin/example_dft_real_fwd_usm + +######################################################################## +# DFT complex in-place forward transform with USM API example: +# +# Using APIs: +# USM forward complex in-place +# Run-time dispatch +# +# Using single precision (float) data type +# +# Device will be selected during runtime. +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices +# +######################################################################## + +Running DFT complex forward example on GPU device +Device name is: AMD Radeon PRO W6800 +Running with single precision real data type: +DFT example run_time dispatch +DFT example ran OK +``` + +```none +$ LD_LIBRARY_PATH=lib/:$LD_LIBRARY_PATH ./bin/example_dft_real_fwd_usm +######################################################################## +# DFT complex in-place forward transform with USM API example: +# +# Using APIs: +# USM forward complex in-place +# Run-time dispatch +# +# Using single precision (float) data type +# +# Device will be selected during runtime. +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices +# +######################################################################## + +Running DFT complex forward example on GPU device +Device name is: Intel(R) UHD Graphics 750 +Running with single precision real data type: +DFT example run_time dispatch +Unsupported Configuration: + oneMKL: dft/backends/portfft/commit: function is not implemented portFFT only supports complex to complex transforms +``` + +## sparse_blas + +Run-time dispatching examples with mklcpu backend +``` +$ export ONEAPI_DEVICE_SELECTOR="opencl:cpu" +$ ./bin/example_sparse_blas_gemv_usm + +######################################################################## +# Sparse Matrix-Vector Multiply Example: +# +# y = alpha * op(A) * x + beta * y +# +# where A is a sparse matrix in CSR format, x and y are dense vectors +# and alpha, beta are floating point type precision scalars. +# +# Using apis: +# sparse::gemv +# +# Using single precision (float) data type +# +# Device will be selected during runtime. +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices +# +######################################################################## + +Running Sparse BLAS GEMV USM example on CPU device. +Device name is: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz +Running with single precision real data type: + + sparse::gemv parameters: + transA = nontrans + nrows = 64 + alpha = 1, beta = 0 + + sparse::gemv example passed + Finished +Sparse BLAS GEMV USM example ran OK. +``` + +Run-time dispatching examples with mklgpu backend +``` +$ export ONEAPI_DEVICE_SELECTOR="level_zero:gpu" +$ ./bin/example_sparse_blas_gemv_usm + +######################################################################## +# Sparse Matrix-Vector Multiply Example: +# +# y = alpha * op(A) * x + beta * y +# +# where A is a sparse matrix in CSR format, x and y are dense vectors +# and alpha, beta are floating point type precision scalars. +# +# Using apis: +# sparse::gemv +# +# Using single precision (float) data type +# +# Device will be selected during runtime. +# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify +# available devices +# +######################################################################## + +Running Sparse BLAS GEMV USM example on GPU device. +Device name is: Intel(R) HD Graphics 530 [0x1912] +Running with single precision real data type: + + sparse::gemv parameters: + transA = nontrans + nrows = 64 + alpha = 1, beta = 0 + + sparse::gemv example passed + Finished +Sparse BLAS GEMV USM example ran OK. +``` + +Compile-time dispatching example with mklcpu backend +``` +$ export ONEAPI_DEVICE_SELECTOR="opencl:cpu" +$ ./bin/example_sparse_blas_gemv_usm_mklcpu + +######################################################################## +# Sparse Matrix-Vector Multiply Example: +# +# y = alpha * op(A) * x + beta * y +# +# where A is a sparse matrix in CSR format, x and y are dense vectors +# and alpha, beta are floating point type precision scalars. +# +# Using apis: +# sparse::gemv +# +# Using single precision (float) data type +# +# Running on Intel CPU device +# +######################################################################## + +Running Sparse BLAS GEMV USM example on CPU device. +Device name is: Intel(R) Core(TM) i7-6700K CPU @ 4.00GHz +Running with single precision real data type: + + sparse::gemv parameters: + transA = nontrans + nrows = 64 + alpha = 1, beta = 0 + + sparse::gemv example passed + Finished +Sparse BLAS GEMV USM example ran OK. ``` diff --git a/examples/blas/run_time_dispatching/level3/CMakeLists.txt b/examples/blas/run_time_dispatching/level3/CMakeLists.txt index 19f830e27..d0d35fc0d 100644 --- a/examples/blas/run_time_dispatching/level3/CMakeLists.txt +++ b/examples/blas/run_time_dispatching/level3/CMakeLists.txt @@ -17,8 +17,8 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -# NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example without specifying backend in CMake -# $ENV{SYCL_DEVICE_FILTER} +# NOTE: user needs to set env var ONEAPI_DEVICE_SELECTOR to use runtime example without specifying backend in CMake +# $ENV{ONEAPI_DEVICE_SELECTOR} # Build object from all example sources @@ -26,13 +26,13 @@ set(BLAS_RT_SOURCES "gemm_usm") # Set up for the right backend for run-time dispatching examples # If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to -# overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend +# overwrite ONEAPI_DEVICE_SELECTOR in their environment to run on the desired backend set(DEVICE_FILTERS "") if(ENABLE_MKLCPU_BACKEND) - list(APPEND DEVICE_FILTERS "cpu") + list(APPEND DEVICE_FILTERS "opencl:cpu") endif() if(ENABLE_MKLGPU_BACKEND) - list(APPEND DEVICE_FILTERS "gpu") + list(APPEND DEVICE_FILTERS "level_zero:gpu") endif() if(ENABLE_CUBLAS_BACKEND) list(APPEND DEVICE_FILTERS "cuda:gpu") @@ -40,8 +40,21 @@ endif() if(ENABLE_ROCBLAS_BACKEND) list(APPEND DEVICE_FILTERS "hip:gpu") endif() +if(ENABLE_PORTBLAS_BACKEND) + if(PORTBLAS_TUNING_TARGET) + if(PORTBLAS_TUNING_TARGET MATCHES "INTEL_CPU") + list(APPEND DEVICE_FILTERS "opencl:cpu") + elseif(PORTBLAS_TUNING_TARGET MATCHES "_GPU") + list(APPEND DEVICE_FILTERS "*:gpu") + endif() + else() + # portBLAS default sycl-target is spir64, testing runtime on both supported + # devices. + list(APPEND DEVICE_FILTERS "opencl:cpu;level_zero:gpu") + endif() +endif() -message(STATUS "SYCL_DEVICE_FILTER will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") +message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") foreach(blas_rt_source ${BLAS_RT_SOURCES}) add_executable(example_${domain}_${blas_rt_source} ${blas_rt_source}.cpp) @@ -68,7 +81,7 @@ foreach(blas_rt_source ${BLAS_RT_SOURCES}) add_test(NAME ${domain}/EXAMPLE/RT/${blas_rt_source}/${device_filter} COMMAND example_${domain}_${blas_rt_source}) set_property(TEST ${domain}/EXAMPLE/RT/${blas_rt_source}/${device_filter} PROPERTY ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} - ENVIRONMENT SYCL_DEVICE_FILTER=${device_filter}) + ENVIRONMENT ONEAPI_DEVICE_SELECTOR=${device_filter}) endforeach(device_filter) endforeach(blas_rt_source) diff --git a/examples/blas/run_time_dispatching/level3/gemm_usm.cpp b/examples/blas/run_time_dispatching/level3/gemm_usm.cpp index 0f021bcc3..cd59e7b7f 100644 --- a/examples/blas/run_time_dispatching/level3/gemm_usm.cpp +++ b/examples/blas/run_time_dispatching/level3/gemm_usm.cpp @@ -198,9 +198,9 @@ void print_example_banner() { std::cout << "# Using single precision (float) data type" << std::endl; std::cout << "# " << std::endl; std::cout << "# Device will be selected during runtime." << std::endl; - std::cout << "# The environment variable SYCL_DEVICE_FILTER can be used to specify" + std::cout << "# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify" << std::endl; - std::cout << "# SYCL device" << std::endl; + std::cout << "# available devices" << std::endl; std::cout << "# " << std::endl; std::cout << "########################################################################" << std::endl; diff --git a/examples/dft/compile_time_dispatching/CMakeLists.txt b/examples/dft/compile_time_dispatching/CMakeLists.txt index e6de95444..ed0ca2922 100644 --- a/examples/dft/compile_time_dispatching/CMakeLists.txt +++ b/examples/dft/compile_time_dispatching/CMakeLists.txt @@ -18,26 +18,33 @@ #=============================================================================== #Build object from all sources -set(DFTI_CT_SOURCES "") -if(ENABLE_MKLGPU_BACKEND) - list(APPEND DFTI_CT_SOURCES "complex_fwd_buffer_mklgpu") +set(DFT_CT_SOURCES "") +if (ENABLE_MKLCPU_BACKEND AND ENABLE_CUFFT_BACKEND) + list(APPEND DFT_CT_SOURCES "complex_fwd_usm_mklcpu_cufft") endif() -foreach(dfti_ct_sources ${DFTI_CT_SOURCES}) - add_executable(example_${domain}_${dfti_ct_sources} ${dfti_ct_sources}.cpp) - target_include_directories(example_${domain}_${dfti_ct_sources} +include(WarningsUtils) + +foreach(dft_ct_source ${DFT_CT_SOURCES}) + set(EXAMPLE_NAME example_${domain}_${dft_ct_source}) + add_executable(${EXAMPLE_NAME} ${dft_ct_source}.cpp) + target_include_directories(${EXAMPLE_NAME} PUBLIC ${PROJECT_SOURCE_DIR}/examples/include - PUBLIC ${PROJECT_SOURCE_DIR}/include PUBLIC ${CMAKE_BINARY_DIR}/bin ) - if(domain STREQUAL "dft" AND ENABLE_MKLGPU_BACKEND) - add_dependencies(example_${domain}_${dfti_ct_sources} onemkl_${domain}_mklgpu) - list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklgpu) + + if(domain STREQUAL "dft" AND ENABLE_MKLCPU_BACKEND AND ENABLE_CUFFT_BACKEND) + add_dependencies(${EXAMPLE_NAME} onemkl_${domain}_mklcpu onemkl_${domain}_cufft) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_mklcpu onemkl_${domain}_cufft) endif() - target_link_libraries(example_${domain}_${dfti_ct_sources} PUBLIC - ${ONEMKL_LIBRARIES_${domain}} - ONEMKL::SYCL::SYCL + + target_link_libraries(${EXAMPLE_NAME} PUBLIC + ${ONEMKL_LIBRARIES_${domain}} + onemkl_warnings ) + # Register example as ctest - add_test(NAME ${domain}/EXAMPLE/CT/${dfti_ct_sources} COMMAND example_${domain}_${dfti_ct_sources}) -endforeach(dfti_ct_sources) + add_test(NAME dft/EXAMPLE/CT/${dft_ct_source} COMMAND ${EXAMPLE_NAME}) + +endforeach(dft_ct_source) + diff --git a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu_cufft.cpp similarity index 51% rename from examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp rename to examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu_cufft.cpp index b72c952dc..59c810f3f 100644 --- a/examples/dft/compile_time_dispatching/complex_fwd_buffer_mklgpu.cpp +++ b/examples/dft/compile_time_dispatching/complex_fwd_usm_mklcpu_cufft.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright 2023 Intel Corporation +* Copyright 2024 Intel Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,11 +27,27 @@ #include #endif #include "oneapi/mkl.hpp" +#include -void run_example(const sycl::device& gpu_device) { - constexpr int N = 10; +void run_example(const sycl::device& cpu_device, const sycl::device& gpu_device) { + constexpr std::size_t N = 10; // Catch asynchronous exceptions for cpu + auto cpu_error_handler = [&](sycl::exception_list exceptions) { + for (auto const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) { + // Handle not dft related exceptions that happened during asynchronous call + std::cerr << "Caught asynchronous SYCL exception on CPU device during execution:" + << std::endl; + std::cerr << "\t" << e.what() << std::endl; + } + } + std::exit(2); + }; + // Catch asynchronous exceptions for gpu auto gpu_error_handler = [&](sycl::exception_list exceptions) { for (auto const& e : exceptions) { try { @@ -39,23 +55,36 @@ void run_example(const sycl::device& gpu_device) { } catch (sycl::exception const& e) { // Handle not dft related exceptions that happened during asynchronous call - std::cerr << "Caught asynchronous SYCL exception:" << std::endl; + std::cerr << "Caught asynchronous SYCL exception on GPU device during execution:" + << std::endl; std::cerr << "\t" << e.what() << std::endl; } } std::exit(2); }; + // Preparation CPU device and GPU device + sycl::queue cpu_queue(cpu_device, cpu_error_handler); sycl::queue gpu_queue(gpu_device, gpu_error_handler); - std::vector> input_data(N); - std::vector> output_data(N); + // allocate on CPU device and GPU device + auto cpu_input_data = sycl::malloc_shared>(N, cpu_queue); + auto cpu_output_data = sycl::malloc_shared>(N, cpu_queue); + + auto gpu_input_data = sycl::malloc_shared>(N, gpu_queue); + auto gpu_output_data = sycl::malloc_shared>(N, gpu_queue); + + // Initialize input data + for (std::size_t i = 0; i < N; ++i) { + cpu_input_data[i] = { static_cast(i), static_cast(-i) }; + gpu_input_data[i] = { static_cast(i), static_cast(-i) }; + } // enabling // 1. create descriptors oneapi::mkl::dft::descriptor - desc(N); + desc(static_cast(N)); // 2. variadic set_value desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, @@ -63,16 +92,27 @@ void run_example(const sycl::device& gpu_device) { desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, static_cast(1)); - // 3. commit_descriptor (compile_time MKLGPU) - desc.commit(oneapi::mkl::backend_selector{ gpu_queue }); + // 3a. commit_descriptor (compile_time MKLCPU) + desc.commit(oneapi::mkl::backend_selector{ cpu_queue }); - // 4. compute_forward / compute_backward (MKLGPU) - { - sycl::buffer> input_buffer(input_data.data(), sycl::range<1>(N)); - sycl::buffer> output_buffer(output_data.data(), sycl::range<1>(N)); - oneapi::mkl::dft::compute_forward, std::complex>( - desc, input_buffer, output_buffer); - } + // 4a. compute_forward / compute_backward (MKLCPU) + oneapi::mkl::dft::compute_forward, std::complex>( + desc, cpu_input_data, cpu_output_data); + + // 3b. commit_descriptor (compile_time cuFFT) + desc.commit(oneapi::mkl::backend_selector{ gpu_queue }); + + // 4b. compute_forward / compute_backward (cuFFT) + oneapi::mkl::dft::compute_forward, std::complex>( + desc, gpu_input_data, gpu_output_data); + + cpu_queue.wait_and_throw(); + gpu_queue.wait_and_throw(); + + sycl::free(cpu_input_data, cpu_queue); + sycl::free(gpu_input_data, gpu_queue); + sycl::free(cpu_output_data, cpu_queue); + sycl::free(gpu_output_data, gpu_queue); } // @@ -81,18 +121,16 @@ void run_example(const sycl::device& gpu_device) { void print_example_banner() { std::cout << "\n" "########################################################################\n" - "# Complex out-of-place forward transform for Buffer API's example:\n" + "# Complex out-of-place forward transform for USM API's example:\n" "#\n" "# Using APIs:\n" "# Compile-time dispatch API\n" - "# Buffer forward complex out-of-place\n" + "# USM forward complex out-of-place\n" "#\n" "# Using single precision (float) data type\n" "#\n" - "# For Intel GPU with Intel MKLGPU backend.\n" + "# Running on both Intel CPU and NVIDIA GPU devices.\n" "#\n" - "# The environment variable SYCL_DEVICE_FILTER can be used to specify\n" - "#SYCL device\n" "########################################################################\n" << std::endl; } @@ -100,19 +138,29 @@ void print_example_banner() { // // Main entry point for example. // -int main(int argc, char** argv) { +int main(int /*argc*/, char** /*argv*/) { print_example_banner(); try { + sycl::device cpu_device((sycl::cpu_selector_v)); sycl::device gpu_device((sycl::gpu_selector_v)); - std::cout << "Running DFT Complex forward out-of-place buffer example" << std::endl; - std::cout << "Using compile-time dispatch API with MKLGPU." << std::endl; + + unsigned int vendor_id = gpu_device.get_info(); + if (vendor_id != NVIDIA_ID) { + std::cerr << "FAILED: NVIDIA GPU device not found" << std::endl; + return 1; + } + + std::cout << "Running DFT Complex forward out-of-place usm example" << std::endl; + std::cout << "Using compile-time dispatch API with MKLCPU and cuFFT." << std::endl; std::cout << "Running with single precision real data type on:" << std::endl; + std::cout << "\tCPU device: " << cpu_device.get_info() + << std::endl; std::cout << "\tGPU device :" << gpu_device.get_info() << std::endl; - run_example(gpu_device); - std::cout << "DFT Complex USM example ran OK on MKLGPU" << std::endl; + run_example(cpu_device, gpu_device); + std::cout << "DFT Complex USM example ran OK on MKLCPU and CUFFT" << std::endl; } catch (sycl::exception const& e) { // Handle not dft related exceptions that happened during synchronous call diff --git a/examples/dft/run_time_dispatching/CMakeLists.txt b/examples/dft/run_time_dispatching/CMakeLists.txt index a881a0015..e221c7950 100644 --- a/examples/dft/run_time_dispatching/CMakeLists.txt +++ b/examples/dft/run_time_dispatching/CMakeLists.txt @@ -17,26 +17,37 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -# NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example (no need to specify backend when building with CMake) +# NOTE: user needs to set env var ONEAPI_DEVICE_SELECTOR to use runtime example (no need to specify backend when building with CMake) +include(WarningsUtils) + # Build object from all example sources set(DFT_RT_SOURCES "") -if(ENABLE_MKLGPU_BACKEND) - list(APPEND DFT_RT_SOURCES "real_fwd_usm") -endif() - # Set up for the right backend for run-time dispatching examples # If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to -# overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend +# overwrite ONEAPI_DEVICE_SELECTOR in their environment to run on the desired backend set(DEVICE_FILTERS "") -if(ENABLE_MKLCPU_BACKEND) - list(APPEND DEVICE_FILTERS "cpu") +if(ENABLE_MKLGPU_BACKEND OR ENABLE_MKLCPU_BACKEND OR ENABLE_CUFFT_BACKEND OR ENABLE_ROCFFT_BACKEND OR ENABLE_PORTFFT_BACKEND) + list(APPEND DFT_RT_SOURCES "real_fwd_usm") endif() + if(ENABLE_MKLGPU_BACKEND) - list(APPEND DEVICE_FILTERS "gpu") + list(APPEND DEVICE_FILTERS "level_zero:gpu") +endif() +if(ENABLE_MKLCPU_BACKEND) + list(APPEND DEVICE_FILTERS "opencl:cpu") +endif() +if(ENABLE_PORTFFT_BACKEND) + list(APPEND DEVICE_FILTERS "*:gpu") +endif() +if(ENABLE_CUFFT_BACKEND) + list(APPEND DEVICE_FILTERS "cuda:gpu") +endif() +if(ENABLE_ROCFFT_BACKEND) + list(APPEND DEVICE_FILTERS "hip:gpu") endif() -message(STATUS "SYCL_DEVICE_FILTER will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") +message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") foreach(dft_rt_sources ${DFT_RT_SOURCES}) add_executable(example_${domain}_${dft_rt_sources} ${dft_rt_sources}.cpp) @@ -52,10 +63,11 @@ foreach(dft_rt_sources ${DFT_RT_SOURCES}) add_sycl_to_target(TARGET example_${domain}_${dft_rt_sources} SOURCES ${DFT_RT_SOURCES}) endif() - target_link_libraries(example_${domain}_${dft_rt_sources} PUBLIC - onemkl - ONEMKL::SYCL::SYCL - ${CMAKE_DL_LIBS} + target_link_libraries(example_${domain}_${dft_rt_sources} + PUBLIC onemkl + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC ${CMAKE_DL_LIBS} + PRIVATE onemkl_warnings ) # Register example as ctest @@ -63,7 +75,7 @@ foreach(dft_rt_sources ${DFT_RT_SOURCES}) add_test(NAME ${domain}/EXAMPLE/RT/${dft_rt_sources}/${device_filter} COMMAND example_${domain}_${dft_rt_sources}) set_property(TEST ${domain}/EXAMPLE/RT/${dft_rt_sources}/${device_filter} PROPERTY ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} - ENVIRONMENT SYCL_DEVICE_FILTER=${device_filter}) + ENVIRONMENT ONEAPI_DEVICE_SELECTOR=${device_filter}) endforeach(device_filter) endforeach() diff --git a/examples/dft/run_time_dispatching/real_fwd_usm.cpp b/examples/dft/run_time_dispatching/real_fwd_usm.cpp index 1b88ce14d..c220b0ee7 100644 --- a/examples/dft/run_time_dispatching/real_fwd_usm.cpp +++ b/examples/dft/run_time_dispatching/real_fwd_usm.cpp @@ -31,7 +31,7 @@ #include "oneapi/mkl.hpp" void run_example(const sycl::device& dev) { - int N = 16; + constexpr std::size_t N = 16; // Catch asynchronous exceptions auto exception_handler = [](sycl::exception_list exceptions) { @@ -55,10 +55,9 @@ void run_example(const sycl::device& dev) { // 1. create descriptors oneapi::mkl::dft::descriptor - desc(N); + desc(static_cast(N)); // 2. variadic set_value - desc.set_value(oneapi::mkl::dft::config_param::FORWARD_SCALE, 1.f / N); desc.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, static_cast(1)); desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT, @@ -81,9 +80,8 @@ void run_example(const sycl::device& dev) { // Description of example setup, APIs used and supported floating point type precisions // void print_example_banner() { - std::cout << "\n" - "########################################################################\n" - "# DFTI complex in-place forward transform with USM API example:\n" + std::cout << "########################################################################\n" + "# DFT complex in-place forward transform with USM API example:\n" "#\n" "# Using APIs:\n" "# USM forward complex in-place\n" @@ -92,8 +90,8 @@ void print_example_banner() { "# Using single precision (float) data type\n" "#\n" "# Device will be selected during runtime.\n" - "# The environment variable SYCL_DEVICE_FILTER can be used to specify\n" - "# SYCL device\n" + "# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify\n" + "# available devices\n" "#\n" "########################################################################\n" << std::endl; @@ -103,7 +101,7 @@ void print_example_banner() { // Main entry point for example. // -int main(int argc, char** argv) { +int main(int /*argc*/, char** /*argv*/) { print_example_banner(); try { @@ -124,6 +122,11 @@ int main(int argc, char** argv) { run_example(my_dev); std::cout << "DFT example ran OK" << std::endl; } + catch (oneapi::mkl::unimplemented const& e) { + std::cerr << "Unsupported Configuration:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 0; + } catch (sycl::exception const& e) { std::cerr << "Caught synchronous SYCL exception:" << std::endl; std::cerr << "\t" << e.what() << std::endl; diff --git a/examples/include/example_helper.hpp b/examples/include/example_helper.hpp index 4f73f8971..4a89e6fae 100644 --- a/examples/include/example_helper.hpp +++ b/examples/include/example_helper.hpp @@ -20,14 +20,50 @@ #ifndef __EXAMPLE_HELPER_HPP__ #define __EXAMPLE_HELPER_HPP__ +#if __has_include() +#include +#else +#include +#endif + +#include +#include +#include +#include +#include + +// Complex helpers. +template +struct complex_info { + using real_type = T; + static const bool is_complex = false; +}; + +template +struct complex_info> { + using real_type = T; + static const bool is_complex = true; +}; + +template +struct is_complex : std::false_type {}; +template +struct is_complex> : std::true_type {}; + // // helpers for initializing templated scalar data type values. // template -fp set_fp_value(fp arg1, fp arg2 = 0.0) { +fp set_fp_value(fp arg1, fp /*arg2*/ = fp(0.0)) { return arg1; } +template +std::complex set_fp_value(std::complex arg1, + std::complex arg2 = std::complex(0.0)) { + return std::complex(arg1.real(), arg2.real()); +} + // // print a 2x2 block of data from matrix M using the sycl accessor // @@ -67,4 +103,80 @@ void rand_matrix(vec &M, oneapi::mkl::transpose trans, int m, int n, int ld) { } } +template +intType generate_sparse_matrix(const intType nx, intType *ia, intType *ja, fp *a, + const intType index = 0) { + intType nz = nx, ny = nx; + intType nnz = 0; + intType current_row; + + ia[0] = index; + + for (intType iz = 0; iz < nz; iz++) { + for (intType iy = 0; iy < ny; iy++) { + for (intType ix = 0; ix < nx; ix++) { + current_row = iz * nx * ny + iy * nx + ix; + + for (intType sz = -1; sz <= 1; sz++) { + if (iz + sz > -1 && iz + sz < nz) { + for (intType sy = -1; sy <= 1; sy++) { + if (iy + sy > -1 && iy + sy < ny) { + for (intType sx = -1; sx <= 1; sx++) { + if (ix + sx > -1 && ix + sx < nx) { + intType current_column = + current_row + sz * nx * ny + sy * nx + sx; + ja[nnz] = current_column + index; + if (current_column == current_row) { + a[nnz++] = set_fp_value(fp(26.0)); + } + else { + a[nnz++] = set_fp_value(fp(-1.0)); + } + } // end + // x + // bounds + // test + } // end sx loop + } // end y bounds test + } // end sy loop + } // end z bounds test + } // end sz loop + ia[current_row + 1] = nnz + index; + + } // end ix loop + } // end iy loop + } // end iz loop + return nnz; +} + +template +bool check_errors(fp x, fp x_ref, fp_real bound) { + fp_real aerr = std::abs(x - x_ref); + fp_real rerr = aerr / (std::abs(x_ref) + std::numeric_limits::epsilon()); + bool ok = (rerr <= bound) || (aerr <= bound); + if (!ok) + std::cout << "relative error = " << rerr << " absolute error = " << aerr + << " limit = " << bound; + return ok; +} + +template +bool check_result(fp res, fp ref, intType nFlops, intType index) { + bool check; + using fp_real = typename complex_info::real_type; + fp_real bound = std::numeric_limits::epsilon() * static_cast(nFlops); + check = check_errors(res, ref, bound); + if (!check) + std::cout << " in index: " << index << std::endl; + return check; +} + +template +void free_vec(std::vector &ptr_vec, sycl::queue queue) { + for (auto ptr : ptr_vec) { + sycl::free(ptr, queue); + } + ptr_vec.clear(); +} + #endif //__EXAMPLE_HELPER_HPP__ diff --git a/examples/lapack/run_time_dispatching/CMakeLists.txt b/examples/lapack/run_time_dispatching/CMakeLists.txt index 6d6250c2b..5fcf6a311 100644 --- a/examples/lapack/run_time_dispatching/CMakeLists.txt +++ b/examples/lapack/run_time_dispatching/CMakeLists.txt @@ -17,20 +17,20 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -# NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example without specifying backend in CMake +# NOTE: user needs to set env var ONEAPI_DEVICE_SELECTOR to use runtime example without specifying backend in CMake # Build object from all example sources set(LAPACK_RT_SOURCES "getrs_usm") # Set up for the right backend for run-time dispatching examples # If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to -# overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend +# overwrite ONEAPI_DEVICE_SELECTOR in their environment to run on the desired backend set(DEVICE_FILTERS "") if(ENABLE_MKLCPU_BACKEND) - list(APPEND DEVICE_FILTERS "cpu") + list(APPEND DEVICE_FILTERS "opencl:cpu") endif() if(ENABLE_MKLGPU_BACKEND) - list(APPEND DEVICE_FILTERS "gpu") + list(APPEND DEVICE_FILTERS "level_zero:gpu") endif() if(ENABLE_CUSOLVER_BACKEND) list(APPEND DEVICE_FILTERS "cuda:gpu") @@ -39,7 +39,7 @@ if(ENABLE_ROCSOLVER_BACKEND) list(APPEND DEVICE_FILTERS "hip:gpu") endif() -message(STATUS "SYCL_DEVICE_FILTER will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") +message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") foreach(lapack_rt_source ${LAPACK_RT_SOURCES}) add_executable(example_${domain}_${lapack_rt_source} ${lapack_rt_source}.cpp) @@ -66,7 +66,7 @@ foreach(lapack_rt_source ${LAPACK_RT_SOURCES}) add_test(NAME ${domain}/EXAMPLE/RT/${lapack_rt_source}/${device_filter} COMMAND example_${domain}_${lapack_rt_source}) set_property(TEST ${domain}/EXAMPLE/RT/${lapack_rt_source}/${device_filter} PROPERTY ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} - ENVIRONMENT SYCL_DEVICE_FILTER=${device_filter}) + ENVIRONMENT ONEAPI_DEVICE_SELECTOR=${device_filter}) endforeach(device_filter) endforeach(lapack_rt_source) diff --git a/examples/lapack/run_time_dispatching/getrs_usm.cpp b/examples/lapack/run_time_dispatching/getrs_usm.cpp index f72e68c01..4cf851a7e 100644 --- a/examples/lapack/run_time_dispatching/getrs_usm.cpp +++ b/examples/lapack/run_time_dispatching/getrs_usm.cpp @@ -203,9 +203,9 @@ void print_example_banner() { std::cout << "# Using single precision (float) data type" << std::endl; std::cout << "# " << std::endl; std::cout << "# Device will be selected during runtime." << std::endl; - std::cout << "# The environment variable SYCL_DEVICE_FILTER can be used to specify" + std::cout << "# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify" << std::endl; - std::cout << "# SYCL device" << std::endl; + std::cout << "# available devices" << std::endl; std::cout << "# " << std::endl; std::cout << "########################################################################" << std::endl; diff --git a/examples/rng/CMakeLists.txt b/examples/rng/CMakeLists.txt index bd9c159f3..b2890bf19 100644 --- a/examples/rng/CMakeLists.txt +++ b/examples/rng/CMakeLists.txt @@ -20,6 +20,7 @@ # Note: compile-time example uses both MKLCPU and CURAND backends, therefore # cmake in the sub-directory will only build it if CURAND backend is enabled add_subdirectory(compile_time_dispatching) +add_subdirectory(device) # runtime compilation is only possible with dynamic libraries if (BUILD_SHARED_LIBS) diff --git a/examples/rng/device/CMakeLists.txt b/examples/rng/device/CMakeLists.txt new file mode 100644 index 000000000..1b6ecf2dd --- /dev/null +++ b/examples/rng/device/CMakeLists.txt @@ -0,0 +1,74 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# NOTE: user needs to set env var ONEAPI_DEVICE_SELECTOR to use runtime example (no need to specify backend when building with CMake) + +# Build object from all example sources +set(RNG_DEVICE_SOURCES "uniform") + +# Set up for the right backend for run-time dispatching examples +# If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to +# overwrite ONEAPI_DEVICE_SELECTOR in their environment to run on the desired backend +set(DEVICE_FILTERS "") +if(ENABLE_MKLCPU_BACKEND) + list(APPEND DEVICE_FILTERS "opencl:cpu") +endif() +# RNG only supports mklcpu backend on Windows +if(ENABLE_MKLGPU_BACKEND) + list(APPEND DEVICE_FILTERS "level_zero:gpu") +endif() +if(ENABLE_CURAND_BACKEND) + list(APPEND DEVICE_FILTERS "cuda:gpu") +endif() +if(ENABLE_ROCRAND_BACKEND) + list(APPEND DEVICE_FILTERS "hip:gpu") +endif() + +message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") + +foreach(rng_device_source ${RNG_DEVICE_SOURCES}) + add_executable(example_${domain}_${rng_device_source} ${rng_device_source}.cpp) + target_include_directories(example_${domain}_${rng_device_source} + PUBLIC ${PROJECT_SOURCE_DIR}/examples/rng/device/include + PUBLIC ${PROJECT_SOURCE_DIR}/examples/include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) + + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET example_${domain}_${rng_device_source} SOURCES ${RNG_DEVICE_SOURCES}) + endif() + + target_link_libraries(example_${domain}_${rng_device_source} PUBLIC + ONEMKL::SYCL::SYCL + ) + + if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_link_options(example_${domain}_${rng_device_source} PUBLIC -fsycl -fsycl-device-code-split=per_kernel) + endif() + + # Register example as ctest + foreach(device_filter ${DEVICE_FILTERS}) + add_test(NAME ${domain}/EXAMPLE/DEVICE/${rng_device_source}/${device_filter} COMMAND example_${domain}_${rng_device_source}) + set_property(TEST ${domain}/EXAMPLE/DEVICE/${rng_device_source}/${device_filter} PROPERTY + ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} + ENVIRONMENT ONEAPI_DEVICE_SELECTOR=${device_filter}) + endforeach(device_filter) + +endforeach() diff --git a/examples/rng/device/include/rng_example_helper.hpp b/examples/rng/device/include/rng_example_helper.hpp new file mode 100644 index 000000000..0bcf114b4 --- /dev/null +++ b/examples/rng/device/include/rng_example_helper.hpp @@ -0,0 +1,50 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _RNG_EXAMPLE_HELPER_HPP__ +#define _RNG_EXAMPLE_HELPER_HPP__ + +template +struct has_member_code_meta : std::false_type {}; + +template +struct has_member_code_meta().get_multi_ptr())>> + : std::true_type {}; + +template ::value>::type* = nullptr> +auto get_multi_ptr(T acc) { +// Workaround for AdaptiveCPP, as they do not yet support the get_multi_ptr function +#ifndef __HIPSYCL__ + return acc.get_multi_ptr(); +#else + return acc.get_pointer(); +#endif +}; + +template ::value>::type* = nullptr> +auto get_multi_ptr(T acc) { +// Workaround for AdaptiveCPP, as they do not yet support the get_multi_ptr function +#ifndef __HIPSYCL__ + return acc.template get_multi_ptr(); +#else + return acc.get_pointer(); +#endif +}; + +#endif // _RNG_EXAMPLE_HELPER_HPP__ diff --git a/examples/rng/device/uniform.cpp b/examples/rng/device/uniform.cpp new file mode 100644 index 000000000..a1c097bba --- /dev/null +++ b/examples/rng/device/uniform.cpp @@ -0,0 +1,213 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +* +* Content: +* This example demonstrates usage of oneapi::mkl::rng::device::mcg59 +* random number generator to produce random +* numbers using unifrom distribution on a SYCL device (CPU, GPU). +* +*******************************************************************************/ + +// stl includes +#include +#include + +// oneMKL/SYCL includes +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/rng/device.hpp" + +#include "rng_example_helper.hpp" + +bool isDoubleSupported(sycl::device my_dev) { + return my_dev.get_info().size() != 0; +} + +// example parameters +constexpr std::uint64_t seed = 777; +constexpr std::size_t n = 1024; +constexpr int n_print = 10; + +// +// example show usage of rng device functionality, which can be called from both +// host and device sides with scalar and vector generation +// +template +int run_example(sycl::queue& queue) { + if (VecSize == 1) { + std::cout << "\tRunning scalar example" << std::endl; + } + else { + std::cout << "\tRunning vector example with " << VecSize << " vector size" << std::endl; + } + // prepare array for random numbers + std::vector r_dev(n); + + // submit a kernel to generate on device + { + sycl::buffer r_buf(r_dev.data(), r_dev.size()); + + try { + queue.submit([&](sycl::handler& cgh) { + sycl::accessor r_acc(r_buf, cgh, sycl::write_only); + cgh.parallel_for(sycl::range<1>(n / VecSize), [=](sycl::item<1> item) { + size_t item_id = item.get_id(0); + oneapi::mkl::rng::device::mcg59 engine(seed, item_id * VecSize); + oneapi::mkl::rng::device::uniform distr; + + auto res = oneapi::mkl::rng::device::generate(distr, engine); + if constexpr (VecSize == 1) { + r_acc[item_id] = res; + } + else { + res.store(item_id, get_multi_ptr(r_acc)); + } + }); + }); + queue.wait_and_throw(); + } + catch (sycl::exception const& e) { + std::cout << "\t\tSYCL exception\n" << e.what() << std::endl; + return 1; + } + + std::cout << "\t\tOutput of generator:" << std::endl; + + auto r_acc = sycl::host_accessor(r_buf, sycl::read_only); + std::cout << "first " << n_print << " numbers of " << n << ": " << std::endl; + for (int i = 0; i < n_print; i++) { + std::cout << r_acc[i] << " "; + } + std::cout << std::endl; + } // buffer life-time ends + + // compare results with host-side generation + oneapi::mkl::rng::device::mcg59<1> engine(seed); + oneapi::mkl::rng::device::uniform distr; + + int err = 0; + Type res_host; + for (int i = 0; i < n; i++) { + res_host = oneapi::mkl::rng::device::generate(distr, engine); + if (res_host != r_dev[i]) { + std::cout << "error in " << i << " element " << res_host << " " << r_dev[i] + << std::endl; + err++; + } + } + return err; +} + +// +// description of example setup, APIs used +// +void print_example_banner() { + std::cout << "" << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << "# Generate uniformly distributed random numbers example: " << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using APIs:" << std::endl; + std::cout << "# mcg59 uniform" << std::endl; + std::cout << "# " << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << std::endl; +} + +int main() { + // Catch asynchronous exceptions + auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) { + std::cerr << "Caught asynchronous SYCL exception during generation:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + } + } + std::exit(2); + }; + + print_example_banner(); + + try { + sycl::device my_dev = sycl::device(); + + if (my_dev.is_gpu()) { + std::cout << "Running RNG uniform usm example on GPU device" << std::endl; + std::cout << "Device name is: " << my_dev.get_info() + << std::endl; + } + else { + std::cout << "Running RNG uniform usm example on CPU device" << std::endl; + std::cout << "Device name is: " << my_dev.get_info() + << std::endl; + } + + sycl::queue queue(my_dev, exception_handler); + + std::cout << "\n\tRunning with single precision real data type:" << std::endl; + if (run_example(queue) || run_example(queue)) { + std::cout << "FAILED" << std::endl; + return 1; + } + if (isDoubleSupported(my_dev)) { + std::cout << "\n\tRunning with double precision real data type:" << std::endl; + if (run_example(queue) || run_example(queue)) { + std::cout << "FAILED" << std::endl; + return 1; + } + } + else { + std::cout << "Double precision is not supported for this device" << std::endl; + } + std::cout << "\n\tRunning with integer data type:" << std::endl; + if (run_example(queue) || run_example(queue)) { + std::cout << "FAILED" << std::endl; + return 1; + } + std::cout << "\n\tRunning with unsigned integer data type:" << std::endl; + if (run_example(queue) || run_example(queue)) { + std::cout << "FAILED" << std::endl; + return 1; + } + + std::cout << "Random number generator with uniform distribution ran OK" << std::endl; + } + catch (sycl::exception const& e) { + std::cerr << "Caught synchronous SYCL exception:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + std::cerr << "\tSYCL error code: " << e.code().value() << std::endl; + return 1; + } + catch (std::exception const& e) { + std::cerr << "Caught std::exception during generation:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 1; + } + return 0; +} diff --git a/examples/rng/run_time_dispatching/CMakeLists.txt b/examples/rng/run_time_dispatching/CMakeLists.txt index 5c0392d64..d3bcc0f19 100644 --- a/examples/rng/run_time_dispatching/CMakeLists.txt +++ b/examples/rng/run_time_dispatching/CMakeLists.txt @@ -17,21 +17,21 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -# NOTE: user needs to set env var SYCL_DEVICE_FILTER to use runtime example (no need to specify backend when building with CMake) +# NOTE: user needs to set env var ONEAPI_DEVICE_SELECTOR to use runtime example (no need to specify backend when building with CMake) # Build object from all example sources set(RNG_RT_SOURCES "uniform_usm") # Set up for the right backend for run-time dispatching examples # If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to -# overwrite SYCL_DEVICE_FILTER in their environment to run on the desired backend +# overwrite ONEAPI_DEVICE_SELECTOR in their environment to run on the desired backend set(DEVICE_FILTERS "") if(ENABLE_MKLCPU_BACKEND) - list(APPEND DEVICE_FILTERS "cpu") + list(APPEND DEVICE_FILTERS "opencl:cpu") endif() # RNG only supports mklcpu backend on Windows if(UNIX AND ENABLE_MKLGPU_BACKEND) - list(APPEND DEVICE_FILTERS "gpu") + list(APPEND DEVICE_FILTERS "level_zero:gpu") endif() if(UNIX AND ENABLE_CURAND_BACKEND) list(APPEND DEVICE_FILTERS "cuda:gpu") @@ -40,7 +40,7 @@ if(UNIX AND ENABLE_ROCRAND_BACKEND) list(APPEND DEVICE_FILTERS "hip:gpu") endif() -message(STATUS "SYCL_DEVICE_FILTER will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") +message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") foreach(rng_rt_source ${RNG_RT_SOURCES}) add_executable(example_${domain}_${rng_rt_source} ${rng_rt_source}.cpp) @@ -67,7 +67,7 @@ foreach(rng_rt_source ${RNG_RT_SOURCES}) add_test(NAME ${domain}/EXAMPLE/RT/${rng_rt_source}/${device_filter} COMMAND example_${domain}_${rng_rt_source}) set_property(TEST ${domain}/EXAMPLE/RT/${rng_rt_source}/${device_filter} PROPERTY ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} - ENVIRONMENT SYCL_DEVICE_FILTER=${device_filter}) + ENVIRONMENT ONEAPI_DEVICE_SELECTOR=${device_filter}) endforeach(device_filter) endforeach() diff --git a/examples/rng/run_time_dispatching/uniform_usm.cpp b/examples/rng/run_time_dispatching/uniform_usm.cpp index 62a726d47..8ac7363c8 100644 --- a/examples/rng/run_time_dispatching/uniform_usm.cpp +++ b/examples/rng/run_time_dispatching/uniform_usm.cpp @@ -141,9 +141,9 @@ void print_example_banner() { std::cout << "# Using single precision (float) data type" << std::endl; std::cout << "# " << std::endl; std::cout << "# Device will be selected during runtime." << std::endl; - std::cout << "# The environment variable SYCL_DEVICE_FILTER can be used to specify" + std::cout << "# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify" << std::endl; - std::cout << "# SYCL device" << std::endl; + std::cout << "# available devices" << std::endl; std::cout << "# " << std::endl; std::cout << "########################################################################" << std::endl; diff --git a/examples/sparse_blas/CMakeLists.txt b/examples/sparse_blas/CMakeLists.txt new file mode 100644 index 000000000..721512429 --- /dev/null +++ b/examples/sparse_blas/CMakeLists.txt @@ -0,0 +1,25 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +add_subdirectory(compile_time_dispatching) + +# runtime compilation is only possible with dynamic libraries +if (BUILD_SHARED_LIBS) + add_subdirectory(run_time_dispatching) +endif() diff --git a/examples/sparse_blas/compile_time_dispatching/CMakeLists.txt b/examples/sparse_blas/compile_time_dispatching/CMakeLists.txt new file mode 100644 index 000000000..cb95333b4 --- /dev/null +++ b/examples/sparse_blas/compile_time_dispatching/CMakeLists.txt @@ -0,0 +1,44 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +#Build object from all sources +set(SPARSE_BLAS_BACKENDS "") + +if(ENABLE_MKLCPU_BACKEND) + list(APPEND SPARSE_BLAS_BACKENDS "mklcpu") +endif() + +include(WarningsUtils) + +foreach(backend ${SPARSE_BLAS_BACKENDS}) + set(EXAMPLE_NAME example_sparse_blas_gemv_usm_${backend}) + add_executable(${EXAMPLE_NAME} sparse_blas_gemv_usm_${backend}.cpp) + target_include_directories(${EXAMPLE_NAME} + PUBLIC ${PROJECT_SOURCE_DIR}/examples/include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) + + add_dependencies(${EXAMPLE_NAME} onemkl_sparse_blas_${backend}) + target_link_libraries(${EXAMPLE_NAME} PRIVATE ONEMKL::SYCL::SYCL onemkl_sparse_blas_${backend}) + + # Register example as ctest + add_test(NAME sparse_blas/EXAMPLE/CT/sparse_blas_gemv_usm_${backend} COMMAND ${EXAMPLE_NAME}) +endforeach(backend) + diff --git a/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp b/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp new file mode 100644 index 000000000..edb6d7e1f --- /dev/null +++ b/examples/sparse_blas/compile_time_dispatching/sparse_blas_gemv_usm_mklcpu.cpp @@ -0,0 +1,256 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +* +* Content: +* This example demonstrates use of DPCPP API oneapi::mkl::sparse::gemv +* using unified shared memory to perform general sparse matrix-vector +* multiplication on a INTEL CPU SYCL device. +* +* y = alpha * op(A) * x + beta * y +* +* where op() is defined by one of +* +* oneapi::mkl::transpose::{nontrans,trans,conjtrans} +* +* +* This example demonstrates only single precision (float) data type for +* gemv matrix data +* +* +*******************************************************************************/ + +// stl includes +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "oneapi/mkl.hpp" + +#include "example_helper.hpp" + +// +// Main example for Sparse Matrix-Vector Multiply consisting of +// initialization of A matrix, x and y vectors as well as +// scalars alpha and beta. Then the product +// +// y = alpha * op(A) * x + beta * y +// +// is performed and finally the results are post processed. +// +template +int run_sparse_matrix_vector_multiply_example(const sycl::device &cpu_dev) { + // Matrix data size + intType size = 4; + intType nrows = size * size * size; + + // Set scalar fp values + fp alpha = set_fp_value(fp(1.0)); + fp beta = set_fp_value(fp(0.0)); + + // Catch asynchronous exceptions + auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const &e) { + std::cout << "Caught asynchronous SYCL " + "exception during sparse::gemv:\n" + << e.what() << std::endl; + } + } + }; + + // create execution queue and buffers of matrix data + sycl::queue cpu_queue(cpu_dev, exception_handler); + oneapi::mkl::backend_selector cpu_selector{ cpu_queue }; + + intType *ia, *ja; + fp *a, *x, *y, *z; + std::size_t sizea = static_cast(27 * nrows); + std::size_t sizeja = static_cast(27 * nrows); + std::size_t sizeia = static_cast(nrows + 1); + std::size_t sizevec = static_cast(nrows); + + ia = (intType *)sycl::malloc_shared(sizeia * sizeof(intType), cpu_queue); + ja = (intType *)sycl::malloc_shared(sizeja * sizeof(intType), cpu_queue); + a = (fp *)sycl::malloc_shared(sizea * sizeof(fp), cpu_queue); + x = (fp *)sycl::malloc_shared(sizevec * sizeof(fp), cpu_queue); + y = (fp *)sycl::malloc_shared(sizevec * sizeof(fp), cpu_queue); + z = (fp *)sycl::malloc_shared(sizevec * sizeof(fp), cpu_queue); + + if (!ia || !ja || !a || !x || !y || !z) { + throw std::runtime_error("Failed to allocate USM memory"); + } + + intType nnz = generate_sparse_matrix(size, ia, ja, a); + + // Init vectors x and y + for (int i = 0; i < nrows; i++) { + x[i] = set_fp_value(fp(1.0)); + y[i] = set_fp_value(fp(0.0)); + z[i] = set_fp_value(fp(0.0)); + } + + std::vector int_ptr_vec; + int_ptr_vec.push_back(ia); + int_ptr_vec.push_back(ja); + std::vector fp_ptr_vec; + fp_ptr_vec.push_back(a); + fp_ptr_vec.push_back(x); + fp_ptr_vec.push_back(y); + fp_ptr_vec.push_back(z); + + // + // Execute Matrix Multiply + // + + oneapi::mkl::transpose transA = oneapi::mkl::transpose::nontrans; + std::cout << "\n\t\tsparse::gemv parameters:\n"; + std::cout << "\t\t\ttransA = " + << (transA == oneapi::mkl::transpose::nontrans + ? "nontrans" + : (transA == oneapi::mkl::transpose::trans ? "trans" : "conjtrans")) + << std::endl; + std::cout << "\t\t\tnrows = " << nrows << std::endl; + std::cout << "\t\t\talpha = " << alpha << ", beta = " << beta << std::endl; + + // create and initialize handle for a Sparse Matrix in CSR format + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + + oneapi::mkl::sparse::init_matrix_handle(cpu_selector, &handle); + + auto ev_set = oneapi::mkl::sparse::set_csr_data(cpu_selector, handle, nrows, nrows, nnz, + oneapi::mkl::index_base::zero, ia, ja, a); + + auto ev_opt = oneapi::mkl::sparse::optimize_gemv(cpu_selector, transA, handle, { ev_set }); + + auto ev_gemv = + oneapi::mkl::sparse::gemv(cpu_selector, transA, alpha, handle, x, beta, y, { ev_opt }); + + auto ev_release = + oneapi::mkl::sparse::release_matrix_handle(cpu_selector, &handle, { ev_gemv }); + + ev_release.wait_and_throw(); + + // + // Post Processing + // + + fp *res = y; + const bool isConj = (transA == oneapi::mkl::transpose::conjtrans); + for (intType row = 0; row < nrows; row++) { + z[row] *= beta; + } + for (intType row = 0; row < nrows; row++) { + fp tmp = alpha * x[row]; + for (intType i = ia[row]; i < ia[row + 1]; i++) { + if constexpr (is_complex()) { + z[ja[i]] += tmp * (isConj ? std::conj(a[i]) : a[i]); + } + else { + z[ja[i]] += tmp * a[i]; + } + } + } + + bool good = true; + for (intType row = 0; row < nrows; row++) { + good &= check_result(res[row], z[row], nrows, row); + } + + std::cout << "\n\t\t sparse::gemv example " << (good ? "passed" : "failed") << "\n\tFinished" + << std::endl; + + free_vec(fp_ptr_vec, cpu_queue); + free_vec(int_ptr_vec, cpu_queue); + + if (!good) + return 1; + + return 0; +} + +// +// Description of example setup, apis used and supported floating point type +// precisions +// +void print_example_banner() { + std::cout << "" << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << "# Sparse Matrix-Vector Multiply Example: " << std::endl; + std::cout << "# " << std::endl; + std::cout << "# y = alpha * op(A) * x + beta * y" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# where A is a sparse matrix in CSR format, x and y are " + "dense vectors" + << std::endl; + std::cout << "# and alpha, beta are floating point type precision scalars." << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using apis:" << std::endl; + std::cout << "# sparse::gemv" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using single precision (float) data type" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Running on Intel CPU device" << std::endl; + std::cout << "# " << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << std::endl; +} + +// +// Main entry point for example +// +int main(int /*argc*/, char ** /*argv*/) { + print_example_banner(); + + try { + // TODO: Add cuSPARSE compile-time dispatcher in this example once it is supported. + sycl::device cpu_dev(sycl::cpu_selector_v); + + std::cout << "Running Sparse BLAS GEMV USM example on CPU device." << std::endl; + std::cout << "Device name is: " << cpu_dev.get_info() + << std::endl; + std::cout << "Running with single precision real data type:" << std::endl; + + run_sparse_matrix_vector_multiply_example(cpu_dev); + std::cout << "Sparse BLAS GEMV USM example ran OK." << std::endl; + } + catch (sycl::exception const &e) { + std::cerr << "Caught synchronous SYCL exception during Sparse GEMV:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + std::cerr << "\tSYCL error code: " << e.code().value() << std::endl; + return 1; + } + catch (std::exception const &e) { + std::cerr << "Caught std::exception during Sparse GEMV:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/examples/sparse_blas/run_time_dispatching/CMakeLists.txt b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt new file mode 100644 index 000000000..6f144c898 --- /dev/null +++ b/examples/sparse_blas/run_time_dispatching/CMakeLists.txt @@ -0,0 +1,68 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# NOTE: user needs to set env var ONEAPI_DEVICE_SELECTOR to use runtime example (no need to specify backend when building with CMake) + +include(WarningsUtils) + +# Build object from all example sources +set(SPARSE_BLAS_RT_SOURCES "sparse_blas_gemv_usm") +# Set up for the right backend for run-time dispatching examples +# If users build more than one backend (i.e. mklcpu and mklgpu, or mklcpu and CUDA), they may need to +# overwrite ONEAPI_DEVICE_SELECTOR in their environment to run on the desired backend +set(DEVICE_FILTERS "") +if(ENABLE_MKLCPU_BACKEND) + list(APPEND DEVICE_FILTERS "opencl:cpu") +endif() +if(ENABLE_MKLGPU_BACKEND) + list(APPEND DEVICE_FILTERS "level_zero:gpu") +endif() + +message(STATUS "ONEAPI_DEVICE_SELECTOR will be set to the following value(s): [${DEVICE_FILTERS}] for run-time dispatching examples") + +foreach(sparse_blas_rt_sources ${SPARSE_BLAS_RT_SOURCES}) + add_executable(example_${sparse_blas_rt_sources} ${sparse_blas_rt_sources}.cpp) + target_include_directories(example_${sparse_blas_rt_sources} + PUBLIC ${PROJECT_SOURCE_DIR}/examples/include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) + + add_dependencies(example_${sparse_blas_rt_sources} onemkl) + + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET example_${sparse_blas_rt_sources} SOURCES ${SPARSE_BLAS_RT_SOURCES}) + endif() + + target_link_libraries(example_${sparse_blas_rt_sources} + PUBLIC onemkl + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC ${CMAKE_DL_LIBS} + PRIVATE onemkl_warnings + ) + + # Register example as ctest + foreach(device_filter ${DEVICE_FILTERS}) + add_test(NAME ${domain}/EXAMPLE/RT/${sparse_blas_rt_sources}/${device_filter} COMMAND example_${sparse_blas_rt_sources}) + set_property(TEST ${domain}/EXAMPLE/RT/${sparse_blas_rt_sources}/${device_filter} PROPERTY + ENVIRONMENT LD_LIBRARY_PATH=${CMAKE_BINARY_DIR}/lib:$ENV{LD_LIBRARY_PATH} + ENVIRONMENT ONEAPI_DEVICE_SELECTOR=${device_filter}) + endforeach(device_filter) + +endforeach() diff --git a/examples/sparse_blas/run_time_dispatching/sparse_blas_gemv_usm.cpp b/examples/sparse_blas/run_time_dispatching/sparse_blas_gemv_usm.cpp new file mode 100644 index 000000000..b5812fabf --- /dev/null +++ b/examples/sparse_blas/run_time_dispatching/sparse_blas_gemv_usm.cpp @@ -0,0 +1,264 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +* +* Content: +* This example demonstrates use of DPCPP API oneapi::mkl::sparse::gemv +* using unified shared memory to perform general sparse matrix-vector +* multiplication on a SYCL device (HOST, CPU, GPU) that is selected +* during runtime. +* +* y = alpha * op(A) * x + beta * y +* +* where op() is defined by one of +* +* oneapi::mkl::transpose::{nontrans,trans,conjtrans} +* +* +* This example demonstrates only single precision (float) data type for +* gemv matrix data +* +* +*******************************************************************************/ + +// stl includes +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "oneapi/mkl.hpp" + +#include "example_helper.hpp" + +// +// Main example for Sparse Matrix-Vector Multiply consisting of +// initialization of A matrix, x and y vectors as well as +// scalars alpha and beta. Then the product +// +// y = alpha * op(A) * x + beta * y +// +// is performed and finally the results are post processed. +// +template +int run_sparse_matrix_vector_multiply_example(const sycl::device &dev) { + // Matrix data size + intType size = 4; + intType nrows = size * size * size; + + // Set scalar fp values + fp alpha = set_fp_value(fp(1.0)); + fp beta = set_fp_value(fp(0.0)); + + // Catch asynchronous exceptions + auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const &e) { + std::cout << "Caught asynchronous SYCL " + "exception during sparse::gemv:\n" + << e.what() << std::endl; + } + } + }; + + // create execution queue and buffers of matrix data + sycl::queue main_queue(dev, exception_handler); + + intType *ia, *ja; + fp *a, *x, *y, *z; + std::size_t sizea = static_cast(27 * nrows); + std::size_t sizeja = static_cast(27 * nrows); + std::size_t sizeia = static_cast(nrows + 1); + std::size_t sizevec = static_cast(nrows); + + ia = (intType *)sycl::malloc_shared(sizeia * sizeof(intType), main_queue); + ja = (intType *)sycl::malloc_shared(sizeja * sizeof(intType), main_queue); + a = (fp *)sycl::malloc_shared(sizea * sizeof(fp), main_queue); + x = (fp *)sycl::malloc_shared(sizevec * sizeof(fp), main_queue); + y = (fp *)sycl::malloc_shared(sizevec * sizeof(fp), main_queue); + z = (fp *)sycl::malloc_shared(sizevec * sizeof(fp), main_queue); + + if (!ia || !ja || !a || !x || !y || !z) { + throw std::runtime_error("Failed to allocate USM memory"); + } + + intType nnz = generate_sparse_matrix(size, ia, ja, a); + + // Init vectors x and y + for (int i = 0; i < nrows; i++) { + x[i] = set_fp_value(fp(1.0)); + y[i] = set_fp_value(fp(0.0)); + z[i] = set_fp_value(fp(0.0)); + } + + std::vector int_ptr_vec; + int_ptr_vec.push_back(ia); + int_ptr_vec.push_back(ja); + std::vector fp_ptr_vec; + fp_ptr_vec.push_back(a); + fp_ptr_vec.push_back(x); + fp_ptr_vec.push_back(y); + fp_ptr_vec.push_back(z); + + // + // Execute Matrix Multiply + // + + oneapi::mkl::transpose transA = oneapi::mkl::transpose::nontrans; + std::cout << "\n\t\tsparse::gemv parameters:\n"; + std::cout << "\t\t\ttransA = " + << (transA == oneapi::mkl::transpose::nontrans + ? "nontrans" + : (transA == oneapi::mkl::transpose::trans ? "trans" : "conjtrans")) + << std::endl; + std::cout << "\t\t\tnrows = " << nrows << std::endl; + std::cout << "\t\t\talpha = " << alpha << ", beta = " << beta << std::endl; + + // create and initialize handle for a Sparse Matrix in CSR format + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + + oneapi::mkl::sparse::init_matrix_handle(main_queue, &handle); + + auto ev_set = oneapi::mkl::sparse::set_csr_data(main_queue, handle, nrows, nrows, nnz, + oneapi::mkl::index_base::zero, ia, ja, a); + + auto ev_opt = oneapi::mkl::sparse::optimize_gemv(main_queue, transA, handle, { ev_set }); + + auto ev_gemv = + oneapi::mkl::sparse::gemv(main_queue, transA, alpha, handle, x, beta, y, { ev_opt }); + + auto ev_release = oneapi::mkl::sparse::release_matrix_handle(main_queue, &handle, { ev_gemv }); + + ev_release.wait_and_throw(); + + // + // Post Processing + // + + fp *res = y; + const bool isConj = (transA == oneapi::mkl::transpose::conjtrans); + for (intType row = 0; row < nrows; row++) { + z[row] *= beta; + } + for (intType row = 0; row < nrows; row++) { + fp tmp = alpha * x[row]; + for (intType i = ia[row]; i < ia[row + 1]; i++) { + if constexpr (is_complex()) { + z[ja[i]] += tmp * (isConj ? std::conj(a[i]) : a[i]); + } + else { + z[ja[i]] += tmp * a[i]; + } + } + } + + bool good = true; + for (intType row = 0; row < nrows; row++) { + good &= check_result(res[row], z[row], nrows, row); + } + + std::cout << "\n\t\t sparse::gemv example " << (good ? "passed" : "failed") << "\n\tFinished" + << std::endl; + + free_vec(fp_ptr_vec, main_queue); + free_vec(int_ptr_vec, main_queue); + + if (!good) + return 1; + + return 0; +} + +// +// Description of example setup, apis used and supported floating point type +// precisions +// +void print_example_banner() { + std::cout << "" << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << "# Sparse Matrix-Vector Multiply Example: " << std::endl; + std::cout << "# " << std::endl; + std::cout << "# y = alpha * op(A) * x + beta * y" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# where A is a sparse matrix in CSR format, x and y are " + "dense vectors" + << std::endl; + std::cout << "# and alpha, beta are floating point type precision scalars." << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using apis:" << std::endl; + std::cout << "# sparse::gemv" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Using single precision (float) data type" << std::endl; + std::cout << "# " << std::endl; + std::cout << "# Device will be selected during runtime." << std::endl; + std::cout << "# The environment variable ONEAPI_DEVICE_SELECTOR can be used to specify" + << std::endl; + std::cout << "# available devices" << std::endl; + std::cout << "# " << std::endl; + std::cout << "########################################################################" + << std::endl; + std::cout << std::endl; +} + +// +// Main entry point for example +// +int main(int /*argc*/, char ** /*argv*/) { + print_example_banner(); + + try { + sycl::device dev = sycl::device(); + + if (dev.is_gpu()) { + std::cout << "Running Sparse BLAS GEMV USM example on GPU device." << std::endl; + std::cout << "Device name is: " << dev.get_info() + << std::endl; + } + else { + std::cout << "Running Sparse BLAS GEMV USM example on CPU device." << std::endl; + std::cout << "Device name is: " << dev.get_info() + << std::endl; + } + std::cout << "Running with single precision real data type:" << std::endl; + + run_sparse_matrix_vector_multiply_example(dev); + std::cout << "Sparse BLAS GEMV USM example ran OK." << std::endl; + } + catch (sycl::exception const &e) { + std::cerr << "Caught synchronous SYCL exception during Sparse GEMV:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + std::cerr << "\tSYCL error code: " << e.code().value() << std::endl; + return 1; + } + catch (std::exception const &e) { + std::cerr << "Caught std::exception during Sparse GEMV:" << std::endl; + std::cerr << "\t" << e.what() << std::endl; + return 1; + } + + return 0; +} diff --git a/include/oneapi/mkl.hpp b/include/oneapi/mkl.hpp index a49c1ceda..f3e9b8618 100644 --- a/include/oneapi/mkl.hpp +++ b/include/oneapi/mkl.hpp @@ -26,5 +26,6 @@ #include "oneapi/mkl/dft.hpp" #include "oneapi/mkl/lapack.hpp" #include "oneapi/mkl/rng.hpp" +#include "oneapi/mkl/sparse_blas.hpp" #endif //_ONEMKL_HPP_ diff --git a/include/oneapi/mkl/blas.hpp b/include/oneapi/mkl/blas.hpp index 1dbcaf2b0..05458d9aa 100644 --- a/include/oneapi/mkl/blas.hpp +++ b/include/oneapi/mkl/blas.hpp @@ -49,6 +49,9 @@ #ifdef ENABLE_NETLIB_BACKEND #include "oneapi/mkl/blas/detail/netlib/blas_ct.hpp" #endif +#ifdef ENABLE_PORTBLAS_BACKEND +#include "oneapi/mkl/blas/detail/portblas/blas_ct.hpp" +#endif namespace oneapi { namespace mkl { diff --git a/include/oneapi/mkl/blas.hxx b/include/oneapi/mkl/blas.hxx index 4ae7053a8..374585912 100644 --- a/include/oneapi/mkl/blas.hxx +++ b/include/oneapi/mkl/blas.hxx @@ -382,6 +382,39 @@ static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose tr stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); } +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +static inline void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + static inline void gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, @@ -1614,6 +1647,40 @@ static inline void omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, detail::omatcopy(get_device_id(queue), queue, trans, m, n, alpha, a, lda, b, ldb); } +static inline void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, b, ldb, + strideb); +} + +static inline void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, b, ldb, + strideb); +} + +static inline void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t strideb) { + detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, b, ldb, + strideb); +} + +static inline void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t strideb) { + detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, b, ldb, + strideb); +} + static inline void imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -2212,6 +2279,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, return done; } +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}) { + auto done = + detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, lda, b, + ldb, beta, c, ldc, group_count, group_size, dependencies); + return done; +} + static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, @@ -2278,6 +2384,45 @@ static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, trans return done; } +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + +static inline sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}) { + auto done = detail::gemm_batch(get_device_id(queue), queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, + batch_size, dependencies); + return done; +} + static inline sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, @@ -4056,6 +4201,48 @@ static inline sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int return done; } +static inline sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}) { + auto done = detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb, dependencies); + return done; +} + +static inline sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}) { + auto done = detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb, dependencies); + return done; +} + +static inline sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}) { + auto done = detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb, dependencies); + return done; +} + +static inline sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}) { + auto done = detail::omatcopy2(get_device_id(queue), queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb, dependencies); + return done; +} + static inline sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp b/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp index b58139b32..eb894b5b9 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hpp @@ -51,6 +51,9 @@ namespace column_major { #define BACKEND netlib #include "blas_ct_backends.hxx" #undef BACKEND +#define BACKEND portblas +#include "blas_ct_backends.hxx" +#undef BACKEND } //namespace column_major namespace row_major { @@ -70,6 +73,9 @@ namespace row_major { #define BACKEND netlib #include "blas_ct_backends.hxx" #undef BACKEND +#define BACKEND portblas +#include "blas_ct_backends.hxx" +#undef BACKEND } //namespace row_major } //namespace blas diff --git a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx index 6f6071f97..afebb93c3 100644 --- a/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx +++ b/include/oneapi/mkl/blas/detail/blas_ct_backends.hxx @@ -464,6 +464,30 @@ static inline void gemm_batch(backend_selector selector, trans sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + +static inline void gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + static inline void spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, float beta, @@ -1178,6 +1202,28 @@ static inline void omatcopy(backend_selector selector, transpo sycl::buffer, 1> &a, std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb); +static inline void omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + +static inline void omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &b, std::int64_t ldb, std::int64_t strideb); + +static inline void omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + +static inline void omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + static inline void imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb); @@ -1848,6 +1894,30 @@ static inline sycl::event gemm_batch(backend_selector selector std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, + float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, + float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, + float *beta, std::int32_t **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + static inline sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -1889,6 +1959,33 @@ static inline sycl::event gemm_batch( sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +static inline sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + static inline sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, @@ -2736,6 +2833,32 @@ static inline sycl::event omatcopy(backend_selector selector, std::complex *b, std::int64_t ldb, const std::vector &dependencies = {}); +static inline sycl::event omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stridea, float *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +static inline sycl::event omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stridea, double *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +static inline sycl::event omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}); + +static inline sycl::event omatcopy2(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}); + static inline sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, diff --git a/include/oneapi/mkl/blas/detail/blas_loader.hxx b/include/oneapi/mkl/blas/detail/blas_loader.hxx index 86c5499dd..98d93b2ad 100644 --- a/include/oneapi/mkl/blas/detail/blas_loader.hxx +++ b/include/oneapi/mkl/blas/detail/blas_loader.hxx @@ -124,6 +124,27 @@ ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, tr std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); ONEMKL_EXPORT void syrk(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, @@ -1043,6 +1064,25 @@ ONEMKL_EXPORT void omatcopy(oneapi::mkl::device libkey, sycl::queue &queue, tran sycl::buffer, 1> &a, std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb); +ONEMKL_EXPORT void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); +ONEMKL_EXPORT void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &b, std::int64_t ldb, std::int64_t strideb); +ONEMKL_EXPORT void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); +ONEMKL_EXPORT void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + ONEMKL_EXPORT void imatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb); @@ -1208,6 +1248,29 @@ ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &qu sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -1244,6 +1307,30 @@ ONEMKL_EXPORT sycl::event gemm_batch( std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, + transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); ONEMKL_EXPORT sycl::event syrk(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, @@ -2494,6 +2581,29 @@ ONEMKL_EXPORT sycl::event omatcopy(oneapi::mkl::device libkey, sycl::queue &queu std::complex *b, std::int64_t ldb, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stridea, float *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stridea, double *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event imatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, diff --git a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx index a884cb614..9483a66c1 100644 --- a/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/blas_ct.hxx @@ -186,6 +186,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::cublas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -1626,6 +1659,38 @@ void omatcopy(backend_selector selector, transpose trans, std:: ldb); } +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::cublas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::cublas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::cublas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::cublas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + void imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -2638,6 +2703,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2705,6 +2806,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, @@ -4039,6 +4176,44 @@ sycl::event omatcopy(backend_selector selector, transpose trans return done; } +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::cublas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx index 40e30b156..1141eb238 100644 --- a/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx +++ b/include/oneapi/mkl/blas/detail/cublas/onemkl_blas_cublas.hxx @@ -804,6 +804,25 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -947,6 +966,23 @@ void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &b, int64_t ldb); +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &b, int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, int64_t ldb, + std::int64_t strideb); + void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb); @@ -2023,6 +2059,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies = {}); + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, const float *b, @@ -2064,6 +2118,27 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t batch_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, std::int32_t *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); + sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, @@ -2197,6 +2272,24 @@ sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex *b, int64_t ldb, const std::vector &dependencies = {}); +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, std::int64_t stridea, float *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, std::int64_t stridea, double *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx index 94b0987f3..1724bf5c7 100644 --- a/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklcpu/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -1628,6 +1661,38 @@ void omatcopy(backend_selector selector, transpose trans, std:: ldb); } +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + void imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -2640,6 +2705,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2707,6 +2808,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, @@ -4041,6 +4178,44 @@ sycl::event omatcopy(backend_selector selector, transpose trans return done; } +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklcpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx index 2f0f88fd2..c69257e9c 100644 --- a/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/mklgpu/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -1628,6 +1661,38 @@ void omatcopy(backend_selector selector, transpose trans, std:: ldb); } +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + void imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -2590,6 +2655,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose *transa, transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const float **a, std::int64_t *lda, const float **b, @@ -2653,6 +2754,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -4041,6 +4178,44 @@ sycl::event omatcopy(backend_selector selector, transpose trans return done; } +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::mklgpu::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx index 6e1257ee1..404d79ae0 100644 --- a/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/netlib/blas_ct.hxx @@ -188,6 +188,39 @@ void gemm_batch(backend_selector selector, transpose transa, tr ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::netlib::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, + ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, float beta, sycl::buffer &c, std::int64_t ldc) { @@ -1628,6 +1661,38 @@ void omatcopy(backend_selector selector, transpose trans, std:: ldb); } +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::netlib::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::netlib::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::netlib::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::netlib::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + void imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -2640,6 +2705,42 @@ sycl::event gemm_batch(backend_selector selector, transpose *tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -2707,6 +2808,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tra return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, float alpha, const float *a, const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, @@ -4046,6 +4183,44 @@ sycl::event omatcopy(backend_selector selector, transpose trans return done; } +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::netlib::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx index 3c129cdbc..fbb64a6a0 100644 --- a/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx +++ b/include/oneapi/mkl/blas/detail/onemkl_blas_backends.hxx @@ -973,6 +973,30 @@ ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + +ONEMKL_EXPORT void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + ONEMKL_EXPORT void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -1158,6 +1182,28 @@ ONEMKL_EXPORT void omatcopy(sycl::queue &queue, oneapi::mkl::transpose trans, st sycl::buffer, 1> &a, std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb); +ONEMKL_EXPORT void omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + +ONEMKL_EXPORT void omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + +ONEMKL_EXPORT void omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + +ONEMKL_EXPORT void omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + ONEMKL_EXPORT void imatcopy(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb); @@ -2536,6 +2582,32 @@ ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, + std::int64_t *n, std::int64_t *k, float *alpha, + const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, + std::int32_t **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, @@ -2577,6 +2649,33 @@ ONEMKL_EXPORT sycl::event gemm_batch( std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, + std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -2698,6 +2797,7 @@ ONEMKL_EXPORT sycl::event omatadd_batch( const std::complex *b, std::int64_t ldb, std::int64_t stride_b, std::complex *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event omatcopy(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, const float *a, std::int64_t lda, float *b, std::int64_t ldb, @@ -2719,6 +2819,33 @@ ONEMKL_EXPORT sycl::event omatcopy(sycl::queue &queue, oneapi::mkl::transpose tr const std::complex *a, std::int64_t lda, std::complex *b, std::int64_t ldb, const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stridea, float *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stridea, double *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event omatcopy2(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies = {}); + ONEMKL_EXPORT sycl::event imatcopy(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, diff --git a/include/oneapi/mkl/blas/detail/portblas/blas_ct.hpp b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hpp new file mode 100644 index 000000000..6d3b0b2c2 --- /dev/null +++ b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hpp @@ -0,0 +1,57 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _DETAIL_PORTBLAS_BLAS_CT_HPP_ +#define _DETAIL_PORTBLAS_BLAS_CT_HPP_ + +#if __has_include() +#include +#else +#include +#endif +#include +#include + +#include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/detail/backend_selector.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" +#include "oneapi/mkl/blas/detail/blas_ct_backends.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace column_major { + +#define MAJOR column_major +#include "blas_ct.hxx" +#undef MAJOR + +} //namespace column_major +namespace row_major { + +#define MAJOR row_major +#include "blas_ct.hxx" +#undef MAJOR + +} //namespace row_major +} //namespace blas +} //namespace mkl +} //namespace oneapi + +#endif //_DETAIL_PORTBLAS_BLAS_CT_HPP_ diff --git a/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx new file mode 100644 index 000000000..8a66ed707 --- /dev/null +++ b/include/oneapi/mkl/blas/detail/portblas/blas_ct.hxx @@ -0,0 +1,4296 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// Buffer APIs + +void herk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer, 1> &a, + std::int64_t lda, float beta, sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::herk(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, beta, c, ldc); +} + +void herk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, sycl::buffer, 1> &a, + std::int64_t lda, double beta, sycl::buffer, 1> &c, + std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::herk(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, beta, c, ldc); +} + +void scal(backend_selector selector, std::int64_t n, float alpha, + sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx); +} + +void scal(backend_selector selector, std::int64_t n, double alpha, + sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx); +} + +void scal(backend_selector selector, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx); +} + +void scal(backend_selector selector, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx); +} + +void scal(backend_selector selector, std::int64_t n, float alpha, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx); +} + +void scal(backend_selector selector, std::int64_t n, double alpha, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx); +} + +void trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, sycl::buffer &x, + std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, sycl::buffer &x, + std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void spr(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &a) { + oneapi::mkl::blas::portblas::MAJOR::spr(selector.get_queue(), upper_lower, n, alpha, x, incx, + a); +} + +void spr(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &a) { + oneapi::mkl::blas::portblas::MAJOR::spr(selector.get_queue(), upper_lower, n, alpha, x, incx, + a); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, double beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::complex beta, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::complex beta, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + sycl::half beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, + float beta, sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syrk(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, beta, c, ldc); +} + +void syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, sycl::buffer &a, + std::int64_t lda, double beta, sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syrk(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, beta, c, ldc); +} + +void syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syrk(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, beta, c, ldc); +} + +void syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syrk(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, beta, c, ldc); +} + +void syrk_batch(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::syrk_batch(selector.get_queue(), upper_lower, trans, n, k, + alpha, a, lda, stride_a, beta, c, ldc, stride_c, + batch_size); +} + +void syrk_batch(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::syrk_batch(selector.get_queue(), upper_lower, trans, n, k, + alpha, a, lda, stride_a, beta, c, ldc, stride_c, + batch_size); +} + +void syrk_batch(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::syrk_batch(selector.get_queue(), upper_lower, trans, n, k, + alpha, a, lda, stride_a, beta, c, ldc, stride_c, + batch_size); +} + +void syrk_batch(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::syrk_batch(selector.get_queue(), upper_lower, trans, n, k, + alpha, a, lda, stride_a, beta, c, ldc, stride_c, + batch_size); +} + +void her2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::her2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a, lda); +} + +void her2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::her2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a, lda); +} + +void hbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::hbmv(selector.get_queue(), upper_lower, n, k, alpha, a, lda, + x, incx, beta, y, incy); +} + +void hbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::hbmv(selector.get_queue(), upper_lower, n, k, alpha, a, lda, + x, incx, beta, y, incy); +} + +void rot(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, float c, float s) { + oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, c, s); +} + +void rot(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, double c, double s) { + oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, c, s); +} + +void rot(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, float c, float s) { + oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, c, s); +} + +void rot(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, double c, double s) { + oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, c, s); +} + +void axpy(backend_selector selector, std::int64_t n, float alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, incy); +} + +void axpy(backend_selector selector, std::int64_t n, double alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, incy); +} + +void axpy(backend_selector selector, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, incy); +} + +void axpy(backend_selector selector, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, incy); +} + +void axpy_batch(backend_selector selector, std::int64_t n, float alpha, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, incx, stridex, + y, incy, stridey, batch_size); +} + +void axpy_batch(backend_selector selector, std::int64_t n, double alpha, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, incx, stridex, + y, incy, stridey, batch_size); +} + +void axpy_batch(backend_selector selector, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, sycl::buffer, 1> &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, incx, stridex, + y, incy, stridey, batch_size); +} + +void axpy_batch(backend_selector selector, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, sycl::buffer, 1> &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, incx, stridex, + y, incy, stridey, batch_size); +} + +void axpby(backend_selector selector, std::int64_t n, float alpha, + sycl::buffer &x, std::int64_t incx, float beta, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, beta, y, + incy); +} + +void axpby(backend_selector selector, std::int64_t n, double alpha, + sycl::buffer &x, std::int64_t incx, double beta, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, beta, y, + incy); +} + +void axpby(backend_selector selector, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, beta, y, + incy); +} + +void axpby(backend_selector selector, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, beta, y, + incy); +} + +void sdsdot(backend_selector selector, std::int64_t n, float sb, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::sdsdot(selector.get_queue(), n, sb, x, incx, y, incy, + result); +} + +void gerc(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::gerc(selector.get_queue(), m, n, alpha, x, incx, y, incy, a, + lda); +} + +void gerc(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::gerc(selector.get_queue(), m, n, alpha, x, incx, y, incy, a, + lda); +} + +void syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, + std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, double beta, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx, float beta, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gemv(selector.get_queue(), trans, m, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx, double beta, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gemv(selector.get_queue(), trans, m, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gemv(selector.get_queue(), trans, m, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx, + std::complex beta, sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gemv(selector.get_queue(), trans, m, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void gemv_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &x, std::int64_t incx, + std::int64_t stridex, float beta, sycl::buffer &y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemv_batch(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size); +} + +void gemv_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &x, std::int64_t incx, + std::int64_t stridex, double beta, sycl::buffer &y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemv_batch(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size); +} + +void gemv_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemv_batch(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size); +} + +void gemv_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + std::complex beta, sycl::buffer, 1> &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::gemv_batch(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size); +} + +void dgmm_batch(backend_selector selector, side left_right, std::int64_t m, + std::int64_t n, sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::dgmm_batch(selector.get_queue(), left_right, m, n, a, lda, + stridea, x, incx, stridex, c, ldc, stridec, + batch_size); +} + +void dgmm_batch(backend_selector selector, side left_right, std::int64_t m, + std::int64_t n, sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::dgmm_batch(selector.get_queue(), left_right, m, n, a, lda, + stridea, x, incx, stridex, c, ldc, stridec, + batch_size); +} + +void dgmm_batch(backend_selector selector, side left_right, std::int64_t m, + std::int64_t n, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &x, std::int64_t incx, + std::int64_t stridex, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stridec, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::dgmm_batch(selector.get_queue(), left_right, m, n, a, lda, + stridea, x, incx, stridex, c, ldc, stridec, + batch_size); +} + +void dgmm_batch(backend_selector selector, side left_right, std::int64_t m, + std::int64_t n, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &x, std::int64_t incx, + std::int64_t stridex, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stridec, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::dgmm_batch(selector.get_queue(), left_right, m, n, a, lda, + stridea, x, incx, stridex, c, ldc, stridec, + batch_size); +} + +void her(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::her(selector.get_queue(), upper_lower, n, alpha, x, incx, a, + lda); +} + +void her(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::her(selector.get_queue(), upper_lower, n, alpha, x, incx, a, + lda); +} + +void hpr(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &a) { + oneapi::mkl::blas::portblas::MAJOR::hpr(selector.get_queue(), upper_lower, n, alpha, x, incx, + a); +} + +void hpr(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &a) { + oneapi::mkl::blas::portblas::MAJOR::hpr(selector.get_queue(), upper_lower, n, alpha, x, incx, + a); +} + +void iamin(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result); +} + +void iamin(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result); +} + +void iamin(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result); +} + +void iamin(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result); +} + +void hpmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::hpmv(selector.get_queue(), upper_lower, n, alpha, a, x, + incx, beta, y, incy); +} + +void hpmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::hpmv(selector.get_queue(), upper_lower, n, alpha, a, x, + incx, beta, y, incy); +} + +void spmv(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, + float beta, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::spmv(selector.get_queue(), upper_lower, n, alpha, a, x, + incx, beta, y, incy); +} + +void spmv(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer &a, sycl::buffer &x, std::int64_t incx, + double beta, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::spmv(selector.get_queue(), upper_lower, n, alpha, a, x, + incx, beta, y, incy); +} + +void gemm_bias(backend_selector selector, transpose transa, transpose transb, + offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, int8_t ao, sycl::buffer &b, + std::int64_t ldb, uint8_t bo, float beta, sycl::buffer &c, + std::int64_t ldc, sycl::buffer &co) { + oneapi::mkl::blas::portblas::MAJOR::gemm_bias(selector.get_queue(), transa, transb, offsetc, m, + n, k, alpha, a, lda, ao, b, ldb, bo, beta, c, ldc, + co); +} + +void gemm_bias(backend_selector selector, transpose transa, transpose transb, + offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, int8_t ao, sycl::buffer &b, + std::int64_t ldb, int8_t bo, float beta, sycl::buffer &c, + std::int64_t ldc, sycl::buffer &co) { + oneapi::mkl::blas::portblas::MAJOR::gemm_bias(selector.get_queue(), transa, transb, offsetc, m, + n, k, alpha, a, lda, ao, b, ldb, bo, beta, c, ldc, + co); +} + +void gemm_bias(backend_selector selector, transpose transa, transpose transb, + offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, uint8_t ao, + sycl::buffer &b, std::int64_t ldb, int8_t bo, float beta, + sycl::buffer &c, std::int64_t ldc, sycl::buffer &co) { + oneapi::mkl::blas::portblas::MAJOR::gemm_bias(selector.get_queue(), transa, transb, offsetc, m, + n, k, alpha, a, lda, ao, b, ldb, bo, beta, c, ldc, + co); +} + +void gemm_bias(backend_selector selector, transpose transa, transpose transb, + offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, uint8_t ao, + sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, + sycl::buffer &c, std::int64_t ldc, sycl::buffer &co) { + oneapi::mkl::blas::portblas::MAJOR::gemm_bias(selector.get_queue(), transa, transb, offsetc, m, + n, k, alpha, a, lda, ao, b, ldb, bo, beta, c, ldc, + co); +} + +void swap(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy); +} + +void swap(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy); +} + +void swap(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy); +} + +void swap(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy); +} + +void geru(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::geru(selector.get_queue(), m, n, alpha, x, incx, y, incy, a, + lda); +} + +void geru(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::geru(selector.get_queue(), m, n, alpha, x, incx, y, incy, a, + lda); +} + +void nrm2(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result); +} + +void nrm2(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result); +} + +void nrm2(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result); +} + +void nrm2(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, double alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, double beta, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::syr2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a, lda); +} + +void syr2(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::syr2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a, lda); +} + +void ger(backend_selector selector, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::ger(selector.get_queue(), m, n, alpha, x, incx, y, incy, a, + lda); +} + +void ger(backend_selector selector, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a, std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::ger(selector.get_queue(), m, n, alpha, x, incx, y, incy, a, + lda); +} + +void trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void dotu(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &result) { + oneapi::mkl::blas::portblas::MAJOR::dotu(selector.get_queue(), n, x, incx, y, incy, result); +} + +void dotu(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &result) { + oneapi::mkl::blas::portblas::MAJOR::dotu(selector.get_queue(), n, x, incx, y, incy, result); +} + +void hemm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::hemm(selector.get_queue(), left_right, upper_lower, m, n, + alpha, a, lda, b, ldb, beta, c, ldc); +} + +void hemm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::hemm(selector.get_queue(), left_right, upper_lower, m, n, + alpha, a, lda, b, ldb, beta, c, ldc); +} + +void hpr2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a) { + oneapi::mkl::blas::portblas::MAJOR::hpr2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a); +} + +void hpr2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a) { + oneapi::mkl::blas::portblas::MAJOR::hpr2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a); +} + +void gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, float alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx, float beta, + sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, a, + lda, x, incx, beta, y, incy); +} + +void gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, double alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &x, + std::int64_t incx, double beta, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, a, + lda, x, incx, beta, y, incy); +} + +void gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, a, + lda, x, incx, beta, y, incy); +} + +void gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, a, + lda, x, incx, beta, y, incy); +} + +void tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbmv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, + std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, upper_lower, m, n, + alpha, a, lda, b, ldb, beta, c, ldc); +} + +void symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, double beta, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, upper_lower, m, n, + alpha, a, lda, b, ldb, beta, c, ldc); +} + +void symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, upper_lower, m, n, + alpha, a, lda, b, ldb, beta, c, ldc); +} + +void symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, upper_lower, m, n, + alpha, a, lda, b, ldb, beta, c, ldc); +} + +void dotc(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &result) { + oneapi::mkl::blas::portblas::MAJOR::dotc(selector.get_queue(), n, x, incx, y, incy, result); +} + +void dotc(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &result) { + oneapi::mkl::blas::portblas::MAJOR::dotc(selector.get_queue(), n, x, incx, y, incy, result); +} + +void syr(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &a, + std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::syr(selector.get_queue(), upper_lower, n, alpha, x, incx, a, + lda); +} + +void syr(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &a, + std::int64_t lda) { + oneapi::mkl::blas::portblas::MAJOR::syr(selector.get_queue(), upper_lower, n, alpha, x, incx, a, + lda); +} + +void trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, upper_lower, trans, + unit_diag, m, n, alpha, a, lda, b, ldb); +} + +void rotmg(backend_selector selector, sycl::buffer &d1, + sycl::buffer &d2, sycl::buffer &x1, float y1, + sycl::buffer ¶m) { + oneapi::mkl::blas::portblas::MAJOR::rotmg(selector.get_queue(), d1, d2, x1, y1, param); +} + +void rotmg(backend_selector selector, sycl::buffer &d1, + sycl::buffer &d2, sycl::buffer &x1, double y1, + sycl::buffer ¶m) { + oneapi::mkl::blas::portblas::MAJOR::rotmg(selector.get_queue(), d1, d2, x1, y1, param); +} + +void tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, sycl::buffer &x, + std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, sycl::buffer &x, + std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, x, incx); +} + +void trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::trsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + a, lda, x, incx); +} + +void copy(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy); +} + +void copy(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy); +} + +void copy(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy); +} + +void copy(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy); +} + +void copy_batch(backend_selector selector, std::int64_t n, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::copy_batch(selector.get_queue(), n, x, incx, stridex, y, + incy, stridey, batch_size); +} + +void copy_batch(backend_selector selector, std::int64_t n, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::copy_batch(selector.get_queue(), n, x, incx, stridex, y, + incy, stridey, batch_size); +} + +void copy_batch(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::copy_batch(selector.get_queue(), n, x, incx, stridex, y, + incy, stridey, batch_size); +} + +void copy_batch(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::copy_batch(selector.get_queue(), n, x, incx, stridex, y, + incy, stridey, batch_size); +} + +void hemv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::hemv(selector.get_queue(), upper_lower, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void hemv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::hemv(selector.get_queue(), upper_lower, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, transb, n, + k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, double alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, double beta, sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, transb, n, + k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, transb, n, + k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, transb, n, + k, alpha, a, lda, b, ldb, beta, c, ldc); +} + +void asum(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result); +} + +void asum(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result); +} + +void asum(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result); +} + +void asum(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result); +} + +void sbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx, float beta, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::sbmv(selector.get_queue(), upper_lower, n, k, alpha, a, lda, + x, incx, beta, y, incy); +} + +void sbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, double alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx, double beta, sycl::buffer &y, + std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::sbmv(selector.get_queue(), upper_lower, n, k, alpha, a, lda, + x, incx, beta, y, incy); +} + +void tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + oneapi::mkl::blas::portblas::MAJOR::tbsv(selector.get_queue(), upper_lower, trans, unit_diag, n, + k, a, lda, x, incx); +} + +void spr2(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a) { + oneapi::mkl::blas::portblas::MAJOR::spr2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a); +} + +void spr2(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a) { + oneapi::mkl::blas::portblas::MAJOR::spr2(selector.get_queue(), upper_lower, n, alpha, x, incx, + y, incy, a); +} + +void iamax(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result); +} + +void iamax(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result); +} + +void iamax(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result); +} + +void iamax(backend_selector selector, std::int64_t n, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result); +} + +void rotm(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer ¶m) { + oneapi::mkl::blas::portblas::MAJOR::rotm(selector.get_queue(), n, x, incx, y, incy, param); +} + +void rotm(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer ¶m) { + oneapi::mkl::blas::portblas::MAJOR::rotm(selector.get_queue(), n, x, incx, y, incy, param); +} + +void dot(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::dot(selector.get_queue(), n, x, incx, y, incy, result); +} + +void dot(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::dot(selector.get_queue(), n, x, incx, y, incy, result); +} + +void dot(backend_selector selector, std::int64_t n, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer &result) { + oneapi::mkl::blas::portblas::MAJOR::dot(selector.get_queue(), n, x, incx, y, incy, result); +} + +void trsm_batch(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::trsm_batch(selector.get_queue(), left_right, upper_lower, + trans, unit_diag, m, n, alpha, a, lda, stride_a, + b, ldb, stride_b, batch_size); +} + +void trsm_batch(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::trsm_batch(selector.get_queue(), left_right, upper_lower, + trans, unit_diag, m, n, alpha, a, lda, stride_a, + b, ldb, stride_b, batch_size); +} + +void trsm_batch(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::trsm_batch(selector.get_queue(), left_right, upper_lower, + trans, unit_diag, m, n, alpha, a, lda, stride_a, + b, ldb, stride_b, batch_size); +} + +void trsm_batch(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::trsm_batch(selector.get_queue(), left_right, upper_lower, + trans, unit_diag, m, n, alpha, a, lda, stride_a, + b, ldb, stride_b, batch_size); +} + +void her2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, float beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::her2k(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void her2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, double beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::her2k(selector.get_queue(), upper_lower, trans, n, k, alpha, + a, lda, b, ldb, beta, c, ldc); +} + +void rotg(backend_selector selector, sycl::buffer &a, + sycl::buffer &b, sycl::buffer &c, sycl::buffer &s) { + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s); +} + +void rotg(backend_selector selector, sycl::buffer &a, + sycl::buffer &b, sycl::buffer &c, sycl::buffer &s) { + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s); +} + +void rotg(backend_selector selector, sycl::buffer, 1> &a, + sycl::buffer, 1> &b, sycl::buffer &c, + sycl::buffer, 1> &s) { + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s); +} + +void rotg(backend_selector selector, sycl::buffer, 1> &a, + sycl::buffer, 1> &b, sycl::buffer &c, + sycl::buffer, 1> &s) { + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s); +} + +void symv(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, + std::int64_t incx, float beta, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::symv(selector.get_queue(), upper_lower, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void symv(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, + std::int64_t incx, double beta, sycl::buffer &y, std::int64_t incy) { + oneapi::mkl::blas::portblas::MAJOR::symv(selector.get_queue(), upper_lower, n, alpha, a, lda, x, + incx, beta, y, incy); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); +} + +void omatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch(selector.get_queue(), trans, m, n, alpha, a, + lda, stride_a, b, ldb, stride_b, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void imatcopy_batch(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch(selector.get_queue(), trans, m, n, alpha, ab, + lda, ldb, stride, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + oneapi::mkl::blas::portblas::MAJOR::omatadd_batch(selector.get_queue(), transa, transb, m, n, + alpha, a, lda, stride_a, beta, b, ldb, + stride_b, c, ldc, stride_c, batch_size); +} + +void omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, alpha, a, lda, + b, ldb); +} + +void omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + sycl::buffer &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, alpha, a, lda, + b, ldb); +} + +void omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, alpha, a, lda, + b, ldb); +} + +void omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, alpha, a, lda, + b, ldb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::portblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, alpha, ab, lda, + ldb); +} + +void imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, alpha, ab, lda, + ldb); +} + +void imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, alpha, ab, lda, + ldb); +} + +void imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb) { + oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, alpha, ab, lda, + ldb); +} + +void omatadd(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, float beta, sycl::buffer &b, std::int64_t ldb, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, n, alpha, + a, lda, beta, b, ldb, c, ldc); +} + +void omatadd(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, double beta, sycl::buffer &b, std::int64_t ldb, + sycl::buffer &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, n, alpha, + a, lda, beta, b, ldb, c, ldc); +} + +void omatadd(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, n, alpha, + a, lda, beta, b, ldb, c, ldc); +} + +void omatadd(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + sycl::buffer, 1> &c, std::int64_t ldc) { + oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, n, alpha, + a, lda, beta, b, ldb, c, ldc); +} + +// USM APIs + +sycl::event syr2(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const float *x, std::int64_t incx, const float *y, std::int64_t incy, + float *a, std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr2( + selector.get_queue(), upper_lower, n, alpha, x, incx, y, incy, a, lda, dependencies); + return done; +} + +sycl::event syr2(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const double *x, std::int64_t incx, const double *y, + std::int64_t incy, double *a, std::int64_t lda, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr2( + selector.get_queue(), upper_lower, n, alpha, x, incx, y, incy, a, lda, dependencies); + return done; +} + +sycl::event scal(backend_selector selector, std::int64_t n, float alpha, + float *x, std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx, + dependencies); + return done; +} + +sycl::event scal(backend_selector selector, std::int64_t n, double alpha, + double *x, std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx, + dependencies); + return done; +} + +sycl::event scal(backend_selector selector, std::int64_t n, + std::complex alpha, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx, + dependencies); + return done; +} + +sycl::event scal(backend_selector selector, std::int64_t n, + std::complex alpha, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx, + dependencies); + return done; +} + +sycl::event scal(backend_selector selector, std::int64_t n, float alpha, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx, + dependencies); + return done; +} + +sycl::event scal(backend_selector selector, std::int64_t n, double alpha, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::scal(selector.get_queue(), n, alpha, x, incx, + dependencies); + return done; +} + +sycl::event trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const float *a, std::int64_t lda, float *x, + std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const double *a, std::int64_t lda, double *x, + std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, std::int64_t lda, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event trmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, std::int64_t lda, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const float *a, float *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const double *a, double *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event tpmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpmv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event spr(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const float *x, std::int64_t incx, float *a, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::spr(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, dependencies); + return done; +} + +sycl::event spr(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const double *x, std::int64_t incx, double *a, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::spr(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, dependencies); + return done; +} + +sycl::event hpmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *a, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hpmv( + selector.get_queue(), upper_lower, n, alpha, a, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event hpmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *a, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hpmv( + selector.get_queue(), upper_lower, n, alpha, a, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, + float beta, float *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, dependencies); + return done; +} + +sycl::event syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, const double *a, std::int64_t lda, + double beta, double *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, dependencies); + return done; +} + +sycl::event syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, std::complex beta, + std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, dependencies); + return done; +} + +sycl::event syrk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, std::complex beta, + std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo *upper_lower, + transpose *trans, std::int64_t *n, std::int64_t *k, float *alpha, + const float **a, std::int64_t *lda, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, group_count, + group_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo *upper_lower, + transpose *trans, std::int64_t *n, std::int64_t *k, double *alpha, + const double **a, std::int64_t *lda, double *beta, double **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, group_count, + group_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo *upper_lower, + transpose *trans, std::int64_t *n, std::int64_t *k, + std::complex *alpha, const std::complex **a, std::int64_t *lda, + std::complex *beta, std::complex **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, group_count, + group_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo *upper_lower, + transpose *trans, std::int64_t *n, std::int64_t *k, + std::complex *alpha, const std::complex **a, + std::int64_t *lda, std::complex *beta, std::complex **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, group_count, + group_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float beta, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, stride_a, beta, c, ldc, + stride_c, batch_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + double *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, stride_a, beta, c, ldc, + stride_c, batch_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, std::complex *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, stride_a, beta, c, ldc, + stride_c, batch_size, dependencies); + return done; +} + +sycl::event syrk_batch(backend_selector selector, uplo upper_lower, + transpose trans, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, std::complex *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syrk_batch( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, stride_a, beta, c, ldc, + stride_c, batch_size, dependencies); + return done; +} + +sycl::event her2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::her2( + selector.get_queue(), upper_lower, n, alpha, x, incx, y, incy, a, lda, dependencies); + return done; +} + +sycl::event her2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::her2( + selector.get_queue(), upper_lower, n, alpha, x, incx, y, incy, a, lda, dependencies); + return done; +} + +sycl::event hbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, std::complex alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + std::complex beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::hbmv(selector.get_queue(), upper_lower, n, k, alpha, a, + lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event hbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, std::complex alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + std::complex beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::hbmv(selector.get_queue(), upper_lower, n, k, alpha, a, + lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event rot(backend_selector selector, std::int64_t n, + std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, float c, float s, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, + c, s, dependencies); + return done; +} + +sycl::event rot(backend_selector selector, std::int64_t n, + std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, double c, double s, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, + c, s, dependencies); + return done; +} + +sycl::event rot(backend_selector selector, std::int64_t n, float *x, + std::int64_t incx, float *y, std::int64_t incy, float c, float s, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, + c, s, dependencies); + return done; +} + +sycl::event rot(backend_selector selector, std::int64_t n, double *x, + std::int64_t incx, double *y, std::int64_t incy, double c, double s, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rot(selector.get_queue(), n, x, incx, y, incy, + c, s, dependencies); + return done; +} + +sycl::event axpy(backend_selector selector, std::int64_t n, float alpha, + const float *x, std::int64_t incx, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, + incy, dependencies); + return done; +} + +sycl::event axpy(backend_selector selector, std::int64_t n, double alpha, + const double *x, std::int64_t incx, double *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, + incy, dependencies); + return done; +} + +sycl::event axpy(backend_selector selector, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, + incy, dependencies); + return done; +} + +sycl::event axpy(backend_selector selector, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy(selector.get_queue(), n, alpha, x, incx, y, + incy, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t *n, float *alpha, + const float **x, std::int64_t *incx, float **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch( + selector.get_queue(), n, alpha, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t *n, double *alpha, + const double **x, std::int64_t *incx, double **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch( + selector.get_queue(), n, alpha, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t *n, + std::complex *alpha, const std::complex **x, + std::int64_t *incx, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch( + selector.get_queue(), n, alpha, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t *n, + std::complex *alpha, const std::complex **x, + std::int64_t *incx, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch( + selector.get_queue(), n, alpha, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t n, float alpha, + const float *x, std::int64_t incx, std::int64_t stridex, float *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, + incx, stridex, y, incy, stridey, + batch_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t n, double alpha, + const double *x, std::int64_t incx, std::int64_t stridex, double *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, + incx, stridex, y, incy, stridey, + batch_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + std::int64_t stridex, std::complex *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, + incx, stridex, y, incy, stridey, + batch_size, dependencies); + return done; +} + +sycl::event axpy_batch(backend_selector selector, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + std::int64_t stridex, std::complex *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpy_batch(selector.get_queue(), n, alpha, x, + incx, stridex, y, incy, stridey, + batch_size, dependencies); + return done; +} + +sycl::event axpby(backend_selector selector, std::int64_t n, float alpha, + const float *x, std::int64_t incx, const float beta, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, + beta, y, incy, dependencies); + return done; +} + +sycl::event axpby(backend_selector selector, std::int64_t n, double alpha, + const double *x, std::int64_t incx, const double beta, double *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, + beta, y, incy, dependencies); + return done; +} + +sycl::event axpby(backend_selector selector, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, + beta, y, incy, dependencies); + return done; +} + +sycl::event axpby(backend_selector selector, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::axpby(selector.get_queue(), n, alpha, x, incx, + beta, y, incy, dependencies); + return done; +} + +sycl::event gerc(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gerc(selector.get_queue(), m, n, alpha, x, incx, + y, incy, a, lda, dependencies); + return done; +} + +sycl::event gerc(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gerc(selector.get_queue(), m, n, alpha, x, incx, + y, incy, a, lda, dependencies); + return done; +} + +sycl::event syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, + const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, + n, k, alpha, a, lda, b, ldb, beta, c, ldc, + dependencies); + return done; +} + +sycl::event syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, const double *a, std::int64_t lda, + const double *b, std::int64_t ldb, double beta, double *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, + n, k, alpha, a, lda, b, ldb, beta, c, ldc, + dependencies); + return done; +} + +sycl::event syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, + n, k, alpha, a, lda, b, ldb, beta, c, ldc, + dependencies); + return done; +} + +sycl::event syr2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr2k(selector.get_queue(), upper_lower, trans, + n, k, alpha, a, lda, b, ldb, beta, c, ldc, + dependencies); + return done; +} + +sycl::event gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, const float *x, + std::int64_t incx, float beta, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, const double *x, + std::int64_t incx, double beta, double *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + std::complex beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gemv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + std::complex beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stridea, const float *x, std::int64_t incx, + std::int64_t stridex, float beta, float *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stridea, const double *x, std::int64_t incx, + std::int64_t stridex, double beta, double *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stridea, + const std::complex *x, std::int64_t incx, std::int64_t stridex, + std::complex beta, std::complex *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stridea, + const std::complex *x, std::int64_t incx, std::int64_t stridex, + std::complex beta, std::complex *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, x, incx, stridex, beta, y, incy, + stridey, batch_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, float *alpha, const float **a, + std::int64_t *lda, const float **x, std::int64_t *incx, float *beta, + float **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, group_count, + group_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, double *alpha, const double **a, + std::int64_t *lda, const double **x, std::int64_t *incx, double *beta, + double **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, group_count, + group_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, std::complex *alpha, + const std::complex **a, std::int64_t *lda, + const std::complex **x, std::int64_t *incx, std::complex *beta, + std::complex **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, group_count, + group_size, dependencies); + return done; +} + +sycl::event gemv_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, std::complex *alpha, + const std::complex **a, std::int64_t *lda, + const std::complex **x, std::int64_t *incx, + std::complex *beta, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemv_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, x, incx, beta, y, incy, group_count, + group_size, dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side left_right, + std::int64_t m, std::int64_t n, const float *a, std::int64_t lda, + std::int64_t stridea, const float *x, std::int64_t incx, + std::int64_t stridex, float *c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, stridea, x, incx, stridex, c, ldc, stridec, + batch_size, dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side left_right, + std::int64_t m, std::int64_t n, const double *a, std::int64_t lda, + std::int64_t stridea, const double *x, std::int64_t incx, + std::int64_t stridex, double *c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, stridea, x, incx, stridex, c, ldc, stridec, + batch_size, dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side left_right, + std::int64_t m, std::int64_t n, const std::complex *a, + std::int64_t lda, std::int64_t stridea, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex *c, + std::int64_t ldc, std::int64_t stridec, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, stridea, x, incx, stridex, c, ldc, stridec, + batch_size, dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side left_right, + std::int64_t m, std::int64_t n, const std::complex *a, + std::int64_t lda, std::int64_t stridea, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex *c, + std::int64_t ldc, std::int64_t stridec, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, stridea, x, incx, stridex, c, ldc, stridec, + batch_size, dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side *left_right, + std::int64_t *m, std::int64_t *n, const float **a, std::int64_t *lda, + const float **x, std::int64_t *incx, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, x, incx, c, ldc, group_count, group_size, + dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side *left_right, + std::int64_t *m, std::int64_t *n, const double **a, std::int64_t *lda, + const double **x, std::int64_t *incx, double **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, x, incx, c, ldc, group_count, group_size, + dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side *left_right, + std::int64_t *m, std::int64_t *n, const std::complex **a, + std::int64_t *lda, const std::complex **x, std::int64_t *incx, + std::complex **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, x, incx, c, ldc, group_count, group_size, + dependencies); + return done; +} + +sycl::event dgmm_batch(backend_selector selector, side *left_right, + std::int64_t *m, std::int64_t *n, const std::complex **a, + std::int64_t *lda, const std::complex **x, std::int64_t *incx, + std::complex **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dgmm_batch( + selector.get_queue(), left_right, m, n, a, lda, x, incx, c, ldc, group_count, group_size, + dependencies); + return done; +} + +sycl::event her(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const std::complex *x, std::int64_t incx, + std::complex *a, std::int64_t lda, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::her(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, lda, dependencies); + return done; +} + +sycl::event her(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const std::complex *x, std::int64_t incx, + std::complex *a, std::int64_t lda, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::her(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, lda, dependencies); + return done; +} + +sycl::event hpr(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const std::complex *x, std::int64_t incx, + std::complex *a, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hpr(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, dependencies); + return done; +} + +sycl::event hpr(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const std::complex *x, std::int64_t incx, + std::complex *a, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hpr(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, dependencies); + return done; +} + +sycl::event iamin(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event iamin(backend_selector selector, std::int64_t n, const double *x, + std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event iamin(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event iamin(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamin(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + sycl::half *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const float **a, std::int64_t *lda, const float **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + double *alpha, const double **a, std::int64_t *lda, const double **b, + std::int64_t *ldb, double *beta, double **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + std::complex *alpha, const std::complex **a, std::int64_t *lda, + const std::complex **b, std::int64_t *ldb, std::complex *beta, + std::complex **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + std::complex *alpha, const std::complex **a, + std::int64_t *lda, const std::complex **b, std::int64_t *ldb, + std::complex *beta, std::complex **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + sycl::half alpha, const sycl::half *a, std::int64_t lda, + std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, + const float *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + double alpha, const double *a, std::int64_t lda, std::int64_t stride_a, + const double *b, std::int64_t ldb, std::int64_t stride_b, double beta, + double *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, const std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::complex beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const float *a, const float *x, std::int64_t incx, float beta, + float *y, std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::spmv( + selector.get_queue(), upper_lower, n, alpha, a, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event spmv(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const double *a, const double *x, std::int64_t incx, double beta, + double *y, std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::spmv( + selector.get_queue(), upper_lower, n, alpha, a, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event swap(backend_selector selector, std::int64_t n, float *x, + std::int64_t incx, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event swap(backend_selector selector, std::int64_t n, double *x, + std::int64_t incx, double *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event swap(backend_selector selector, std::int64_t n, + std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event swap(backend_selector selector, std::int64_t n, + std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::swap(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event geru(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::geru(selector.get_queue(), m, n, alpha, x, incx, + y, incy, a, lda, dependencies); + return done; +} + +sycl::event geru(backend_selector selector, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::geru(selector.get_queue(), m, n, alpha, x, incx, + y, incy, a, lda, dependencies); + return done; +} + +sycl::event nrm2(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, float *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event nrm2(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, double *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event nrm2(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, float *result, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event nrm2(backend_selector selector, std::int64_t n, const double *x, + std::int64_t incx, double *result, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::nrm2(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, + std::int64_t lda, const float *b, std::int64_t ldb, float beta, float *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, double alpha, const double *a, + std::int64_t lda, const double *b, std::int64_t ldb, double beta, double *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, + sycl::half beta, sycl::half *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm(backend_selector selector, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const bfloat16 *a, + std::int64_t lda, const bfloat16 *b, std::int64_t ldb, float beta, float *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gemm(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, b, ldb, beta, c, ldc, dependencies); + return done; +} + +sycl::event gemm_bias(backend_selector selector, transpose transa, + transpose transb, offset offsetc, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int8_t ao, const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, + float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_bias( + selector.get_queue(), transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, beta, + c, ldc, co, dependencies); + return done; +} + +sycl::event gemm_bias(backend_selector selector, transpose transa, + transpose transb, offset offsetc, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int8_t ao, const std::int8_t *b, std::int64_t ldb, std::int8_t bo, + float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_bias( + selector.get_queue(), transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, beta, + c, ldc, co, dependencies); + return done; +} + +sycl::event gemm_bias(backend_selector selector, transpose transa, + transpose transb, offset offsetc, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::uint8_t *a, std::int64_t lda, + std::uint8_t ao, const std::int8_t *b, std::int64_t ldb, std::int8_t bo, + float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_bias( + selector.get_queue(), transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, beta, + c, ldc, co, dependencies); + return done; +} + +sycl::event gemm_bias(backend_selector selector, transpose transa, + transpose transb, offset offsetc, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::uint8_t *a, std::int64_t lda, + std::uint8_t ao, const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, + float beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemm_bias( + selector.get_queue(), transa, transb, offsetc, m, n, k, alpha, a, lda, ao, b, ldb, bo, beta, + c, ldc, co, dependencies); + return done; +} + +sycl::event herk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, float alpha, const std::complex *a, + std::int64_t lda, float beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::herk( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, dependencies); + return done; +} + +sycl::event herk(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, double alpha, const std::complex *a, + std::int64_t lda, double beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::herk( + selector.get_queue(), upper_lower, trans, n, k, alpha, a, lda, beta, c, ldc, dependencies); + return done; +} + +sycl::event ger(backend_selector selector, std::int64_t m, std::int64_t n, + float alpha, const float *x, std::int64_t incx, const float *y, std::int64_t incy, + float *a, std::int64_t lda, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::ger(selector.get_queue(), m, n, alpha, x, incx, + y, incy, a, lda, dependencies); + return done; +} + +sycl::event ger(backend_selector selector, std::int64_t m, std::int64_t n, + double alpha, const double *x, std::int64_t incx, const double *y, + std::int64_t incy, double *a, std::int64_t lda, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::ger(selector.get_queue(), m, n, alpha, x, incx, + y, incy, a, lda, dependencies); + return done; +} + +sycl::event trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, float *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, double *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trsm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side left_right, + uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, + stride_a, b, ldb, stride_b, batch_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side *left_right, + uplo *upper_lower, transpose *trans, diag *unit_diag, std::int64_t *m, + std::int64_t *n, float *alpha, const float **a, std::int64_t *lda, float **b, + std::int64_t *ldb, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb, group_count, group_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side *left_right, + uplo *upper_lower, transpose *trans, diag *unit_diag, std::int64_t *m, + std::int64_t *n, double *alpha, const double **a, std::int64_t *lda, + double **b, std::int64_t *ldb, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb, group_count, group_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side *left_right, + uplo *upper_lower, transpose *trans, diag *unit_diag, std::int64_t *m, + std::int64_t *n, std::complex *alpha, const std::complex **a, + std::int64_t *lda, std::complex **b, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb, group_count, group_size, dependencies); + return done; +} + +sycl::event trsm_batch(backend_selector selector, side *left_right, + uplo *upper_lower, transpose *trans, diag *unit_diag, std::int64_t *m, + std::int64_t *n, std::complex *alpha, const std::complex **a, + std::int64_t *lda, std::complex **b, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsm_batch( + selector.get_queue(), left_right, upper_lower, trans, unit_diag, m, n, alpha, a, lda, b, + ldb, group_count, group_size, dependencies); + return done; +} + +sycl::event dotu(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, const std::complex *y, + std::int64_t incy, std::complex *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dotu(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event dotu(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, const std::complex *y, + std::int64_t incy, std::complex *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dotu(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event hemm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hemm(selector.get_queue(), left_right, + upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); + return done; +} + +sycl::event hemm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hemm(selector.get_queue(), left_right, + upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); + return done; +} + +sycl::event hpr2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hpr2(selector.get_queue(), upper_lower, n, + alpha, x, incx, y, incy, a, dependencies); + return done; +} + +sycl::event hpr2(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hpr2(selector.get_queue(), upper_lower, n, + alpha, x, incx, y, incy, a, dependencies); + return done; +} + +sycl::event gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, float alpha, const float *a, + std::int64_t lda, const float *x, std::int64_t incx, float beta, float *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, + a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, double alpha, const double *a, + std::int64_t lda, const double *x, std::int64_t incx, double beta, double *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, + a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *x, + std::int64_t incx, std::complex beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, + a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gbmv(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::int64_t kl, std::int64_t ku, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *x, + std::int64_t incx, std::complex beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::gbmv(selector.get_queue(), trans, m, n, kl, ku, alpha, + a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const float *a, std::int64_t lda, + float *x, std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const double *a, std::int64_t lda, + double *x, std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const std::complex *a, + std::int64_t lda, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tbmv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const std::complex *a, + std::int64_t lda, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbmv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, float alpha, const float *a, std::int64_t lda, + const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, + upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); + return done; +} + +sycl::event symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, double alpha, const double *a, std::int64_t lda, + const double *b, std::int64_t ldb, double beta, double *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, + upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); + return done; +} + +sycl::event symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, + upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); + return done; +} + +sycl::event symm(backend_selector selector, side left_right, uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::symm(selector.get_queue(), left_right, + upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); + return done; +} + +sycl::event dotc(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, const std::complex *y, + std::int64_t incy, std::complex *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dotc(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event dotc(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, const std::complex *y, + std::int64_t incy, std::complex *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dotc(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event syr(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const float *x, std::int64_t incx, float *a, std::int64_t lda, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, lda, dependencies); + return done; +} + +sycl::event syr(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const double *x, std::int64_t incx, double *a, std::int64_t lda, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::syr(selector.get_queue(), upper_lower, n, alpha, + x, incx, a, lda, dependencies); + return done; +} + +sycl::event trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, float *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, double *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event trmm(backend_selector selector, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trmm(selector.get_queue(), left_right, + upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb, dependencies); + return done; +} + +sycl::event rotmg(backend_selector selector, float *d1, float *d2, float *x1, + float y1, float *param, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rotmg(selector.get_queue(), d1, d2, x1, y1, + param, dependencies); + return done; +} + +sycl::event rotmg(backend_selector selector, double *d1, double *d2, double *x1, + double y1, double *param, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rotmg(selector.get_queue(), d1, d2, x1, y1, + param, dependencies); + return done; +} + +sycl::event tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const float *a, float *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const double *a, double *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event tpsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tpsv(selector.get_queue(), upper_lower, trans, + unit_diag, n, a, x, incx, dependencies); + return done; +} + +sycl::event trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const float *a, std::int64_t lda, float *x, + std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const double *a, std::int64_t lda, double *x, + std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, std::int64_t lda, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event trsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, const std::complex *a, std::int64_t lda, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::trsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, a, lda, x, incx, dependencies); + return done; +} + +sycl::event copy(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event copy(backend_selector selector, std::int64_t n, const double *x, + std::int64_t incx, double *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event copy(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event copy(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy(selector.get_queue(), n, x, incx, y, incy, + dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t *n, + const float **x, std::int64_t *incx, float **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t *n, + const double **x, std::int64_t *incx, double **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t *n, + const std::complex **x, std::int64_t *incx, std::complex **y, + std::int64_t *incy, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t *n, + const std::complex **x, std::int64_t *incx, std::complex **y, + std::int64_t *incy, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, y, incy, group_count, group_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, std::int64_t stridex, float *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, stridex, y, incy, stridey, batch_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t n, + const double *x, std::int64_t incx, std::int64_t stridex, double *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, stridex, y, incy, stridey, batch_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::int64_t stridex, + std::complex *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, stridex, y, incy, stridey, batch_size, dependencies); + return done; +} + +sycl::event copy_batch(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::int64_t stridex, + std::complex *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::copy_batch( + selector.get_queue(), n, x, incx, stridex, y, incy, stridey, batch_size, dependencies); + return done; +} + +sycl::event hemv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hemv( + selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event hemv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::hemv( + selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, + std::int64_t lda, const float *b, std::int64_t ldb, float beta, float *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, + transb, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); + return done; +} + +sycl::event gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, double alpha, const double *a, + std::int64_t lda, const double *b, std::int64_t ldb, double beta, double *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, + transb, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); + return done; +} + +sycl::event gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, + transb, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); + return done; +} + +sycl::event gemmt(backend_selector selector, uplo upper_lower, transpose transa, + transpose transb, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::gemmt(selector.get_queue(), upper_lower, transa, + transb, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); + return done; +} + +sycl::event sbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *x, + std::int64_t incx, float beta, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::sbmv(selector.get_queue(), upper_lower, n, k, alpha, a, + lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event sbmv(backend_selector selector, uplo upper_lower, std::int64_t n, + std::int64_t k, double alpha, const double *a, std::int64_t lda, const double *x, + std::int64_t incx, double beta, double *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::sbmv(selector.get_queue(), upper_lower, n, k, alpha, a, + lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event asum(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, float *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event asum(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, double *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event asum(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, float *result, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event asum(backend_selector selector, std::int64_t n, const double *x, + std::int64_t incx, double *result, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::asum(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const float *a, std::int64_t lda, + float *x, std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const double *a, std::int64_t lda, + double *x, std::int64_t incx, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const std::complex *a, + std::int64_t lda, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event tbsv(backend_selector selector, uplo upper_lower, transpose trans, + diag unit_diag, std::int64_t n, std::int64_t k, const std::complex *a, + std::int64_t lda, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::tbsv( + selector.get_queue(), upper_lower, trans, unit_diag, n, k, a, lda, x, incx, dependencies); + return done; +} + +sycl::event spr2(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const float *x, std::int64_t incx, const float *y, std::int64_t incy, + float *a, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::spr2(selector.get_queue(), upper_lower, n, + alpha, x, incx, y, incy, a, dependencies); + return done; +} + +sycl::event spr2(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const double *x, std::int64_t incx, const double *y, + std::int64_t incy, double *a, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::spr2(selector.get_queue(), upper_lower, n, + alpha, x, incx, y, incy, a, dependencies); + return done; +} + +sycl::event iamax(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event iamax(backend_selector selector, std::int64_t n, const double *x, + std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event iamax(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event iamax(backend_selector selector, std::int64_t n, + const std::complex *x, std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::iamax(selector.get_queue(), n, x, incx, result, + dependencies); + return done; +} + +sycl::event rotm(backend_selector selector, std::int64_t n, float *x, + std::int64_t incx, float *y, std::int64_t incy, float *param, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rotm(selector.get_queue(), n, x, incx, y, incy, + param, dependencies); + return done; +} + +sycl::event rotm(backend_selector selector, std::int64_t n, double *x, + std::int64_t incx, double *y, std::int64_t incy, double *param, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::rotm(selector.get_queue(), n, x, incx, y, incy, + param, dependencies); + return done; +} + +sycl::event rotg(backend_selector selector, float *a, float *b, float *c, + float *s, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s, dependencies); + return done; +} + +sycl::event rotg(backend_selector selector, double *a, double *b, double *c, + double *s, const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s, dependencies); + return done; +} + +sycl::event rotg(backend_selector selector, std::complex *a, + std::complex *b, float *c, std::complex *s, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s, dependencies); + return done; +} + +sycl::event rotg(backend_selector selector, std::complex *a, + std::complex *b, double *c, std::complex *s, + const std::vector &dependencies) { + auto done = + oneapi::mkl::blas::portblas::MAJOR::rotg(selector.get_queue(), a, b, c, s, dependencies); + return done; +} + +sycl::event sdsdot(backend_selector selector, std::int64_t n, float sb, + const float *x, std::int64_t incx, const float *y, std::int64_t incy, + float *result, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::sdsdot(selector.get_queue(), n, sb, x, incx, y, + incy, result, dependencies); + return done; +} + +sycl::event her2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, float beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::her2k(selector.get_queue(), upper_lower, trans, + n, k, alpha, a, lda, b, ldb, beta, c, ldc, + dependencies); + return done; +} + +sycl::event her2k(backend_selector selector, uplo upper_lower, transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, double beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::her2k(selector.get_queue(), upper_lower, trans, + n, k, alpha, a, lda, b, ldb, beta, c, ldc, + dependencies); + return done; +} + +sycl::event dot(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, const float *y, std::int64_t incy, float *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dot(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event dot(backend_selector selector, std::int64_t n, const double *x, + std::int64_t incx, const double *y, std::int64_t incy, double *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dot(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event dot(backend_selector selector, std::int64_t n, const float *x, + std::int64_t incx, const float *y, std::int64_t incy, double *result, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::dot(selector.get_queue(), n, x, incx, y, incy, + result, dependencies); + return done; +} + +sycl::event symv(backend_selector selector, uplo upper_lower, std::int64_t n, + float alpha, const float *a, std::int64_t lda, const float *x, std::int64_t incx, + float beta, float *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::symv( + selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event symv(backend_selector selector, uplo upper_lower, std::int64_t n, + double alpha, const double *a, std::int64_t lda, const double *x, + std::int64_t incx, double beta, double *y, std::int64_t incy, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::symv( + selector.get_queue(), upper_lower, n, alpha, a, lda, x, incx, beta, y, incy, dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, stride_a, b, ldb, stride_b, batch_size, + dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, double alpha, double *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + std::complex *ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, stride, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatadd_batch(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd_batch( + selector.get_queue(), transa, transb, m, n, alpha, a, lda, stride_a, beta, b, ldb, stride_b, + c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, float *b, + std::int64_t ldb, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, + alpha, a, lda, b, ldb, dependencies); + return done; +} + +sycl::event omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, double *b, + std::int64_t ldb, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, + alpha, a, lda, b, ldb, dependencies); + return done; +} + +sycl::event omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, + alpha, a, lda, b, ldb, dependencies); + return done; +} + +sycl::event omatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy(selector.get_queue(), trans, m, n, + alpha, a, lda, b, ldb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, + alpha, ab, lda, ldb, dependencies); + return done; +} + +sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, double *ab, std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, + alpha, ab, lda, ldb, dependencies); + return done; +} + +sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, + alpha, ab, lda, ldb, dependencies); + return done; +} + +sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy(selector.get_queue(), trans, m, n, + alpha, ab, lda, ldb, dependencies); + return done; +} + +sycl::event omatadd(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, float alpha, const float *a, + std::int64_t lda, float beta, const float *b, std::int64_t ldb, float *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, + n, alpha, a, lda, beta, b, ldb, c, ldc, + dependencies); + return done; +} + +sycl::event omatadd(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, double beta, const double *b, std::int64_t ldb, double *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, + n, alpha, a, lda, beta, b, ldb, c, ldc, + dependencies); + return done; +} + +sycl::event omatadd(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::complex beta, + const std::complex *b, std::int64_t ldb, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, + n, alpha, a, lda, beta, b, ldb, c, ldc, + dependencies); + return done; +} + +sycl::event omatadd(backend_selector selector, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::complex beta, + const std::complex *b, std::int64_t ldb, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatadd(selector.get_queue(), transa, transb, m, + n, alpha, a, lda, beta, b, ldb, c, ldc, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, float *alpha, const float **a, + std::int64_t *lda, float **b, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *groupsize, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, b, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, double *alpha, const double **a, + std::int64_t *lda, double **b, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *groupsize, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, b, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, std::complex *alpha, + const std::complex **a, std::int64_t *lda, + std::complex **b, std::int64_t *ldb, std::int64_t group_count, + std::int64_t *groupsize, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, b, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event omatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, std::complex *alpha, + const std::complex **a, std::int64_t *lda, + std::complex **b, std::int64_t *ldb, std::int64_t group_count, + std::int64_t *groupsize, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::omatcopy_batch( + selector.get_queue(), trans, m, n, alpha, a, lda, b, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, float *alpha, float **ab, + std::int64_t *lda, std::int64_t *ldb, std::int64_t group_count, + std::int64_t *groupsize, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, double *alpha, double **ab, + std::int64_t *lda, std::int64_t *ldb, std::int64_t group_count, + std::int64_t *groupsize, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, std::complex *alpha, + std::complex **ab, std::int64_t *lda, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *groupsize, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, group_count, groupsize, + dependencies); + return done; +} + +sycl::event imatcopy_batch(backend_selector selector, transpose *trans, + std::int64_t *m, std::int64_t *n, std::complex *alpha, + std::complex **ab, std::int64_t *lda, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *groupsize, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::portblas::MAJOR::imatcopy_batch( + selector.get_queue(), trans, m, n, alpha, ab, lda, ldb, group_count, groupsize, + dependencies); + return done; +} diff --git a/include/oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp b/include/oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp new file mode 100644 index 000000000..c8d47d742 --- /dev/null +++ b/include/oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp @@ -0,0 +1,61 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_BLAS_PORTBLAS_HPP_ +#define _ONEMKL_BLAS_PORTBLAS_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include + +#include "oneapi/mkl/types.hpp" + +#include "oneapi/mkl/detail/export.hpp" + +namespace oneapi { +namespace mkl { + +using oneapi::mkl::transpose; +using oneapi::mkl::uplo; +using oneapi::mkl::side; +using oneapi::mkl::diag; +using oneapi::mkl::offset; + +namespace blas { +namespace portblas { +namespace column_major { + +#include "oneapi/mkl/blas/detail/onemkl_blas_backends.hxx" + +} //namespace column_major +namespace row_major { + +#include "oneapi/mkl/blas/detail/onemkl_blas_backends.hxx" + +} //namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_BLAS_PORTBLAS_HPP_ diff --git a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx index 24d7beaaf..bc86929b0 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/blas_ct.hxx @@ -181,6 +181,36 @@ void gemm_batch(backend_selector selector, transpose transa, t c, ldc, stride_c, batch_size); } +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + +void gemm_batch(backend_selector selector, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, float alpha, sycl::buffer &a, + int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, + int64_t stride_b, float beta, sycl::buffer &c, int64_t ldc, + int64_t stride_c, int64_t batch_size) { + oneapi::mkl::blas::rocblas::MAJOR::gemm_batch(selector.get_queue(), transa, transb, m, n, k, + alpha, a, lda, stride_a, b, ldb, stride_b, beta, + c, ldc, stride_c, batch_size); +} + void syrk(backend_selector selector, uplo upper_lower, transpose trans, int64_t n, int64_t k, float alpha, sycl::buffer &a, int64_t lda, float beta, sycl::buffer &c, int64_t ldc) { @@ -1552,6 +1582,38 @@ void omatcopy(backend_selector selector, transpose trans, std: ldb); } +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::rocblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + oneapi::mkl::blas::rocblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::rocblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + +void omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + oneapi::mkl::blas::rocblas::MAJOR::omatcopy2(selector.get_queue(), trans, m, n, alpha, a, lda, + stridea, b, ldb, strideb); +} + void imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -2506,6 +2568,39 @@ sycl::event gemm_batch(backend_selector selector, transpose *t return done; } +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const sycl::half **a, int64_t *lda, const sycl::half **b, int64_t *ldb, + float *beta, float **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const std::int8_t **a, int64_t *lda, const std::int8_t **b, int64_t *ldb, + float *beta, float **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose *transa, + transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, + const std::int8_t **a, int64_t *lda, const std::int8_t **b, int64_t *ldb, + float *beta, std::int32_t **c, int64_t *ldc, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, + group_count, group_size, dependencies); + return done; +} + sycl::event gemm_batch(backend_selector selector, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, @@ -2566,6 +2661,42 @@ sycl::event gemm_batch(backend_selector selector, transpose tr return done; } +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const sycl::half *a, int64_t lda, int64_t stride_a, const sycl::half *b, + int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const std::int8_t *a, int64_t lda, int64_t stride_a, const std::int8_t *b, + int64_t ldb, int64_t stride_b, float beta, float *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + +sycl::event gemm_batch(backend_selector selector, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, float alpha, + const std::int8_t *a, int64_t lda, int64_t stride_a, const std::int8_t *b, + int64_t ldb, int64_t stride_b, float beta, std::int32_t *c, int64_t ldc, + int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::gemm_batch( + selector.get_queue(), transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, + beta, c, ldc, stride_c, batch_size, dependencies); + return done; +} + sycl::event spmv(backend_selector selector, uplo upper_lower, int64_t n, float alpha, const float *a, const float *x, int64_t incx, float beta, float *y, int64_t incy, const std::vector &dependencies) { @@ -3844,6 +3975,44 @@ sycl::event omatcopy(backend_selector selector, transpose tran return done; } +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + +sycl::event omatcopy2(backend_selector selector, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + auto done = oneapi::mkl::blas::rocblas::MAJOR::omatcopy2( + selector.get_queue(), trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); + return done; +} + sycl::event imatcopy(backend_selector selector, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { diff --git a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx index c1ef299ad..70aabaaf9 100644 --- a/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx +++ b/include/oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hxx @@ -744,6 +744,24 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t sycl::half beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size); +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size); + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -878,6 +896,23 @@ void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &b, int64_t ldb); +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &b, int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, int64_t ldb, + std::int64_t strideb); + void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb); @@ -1831,6 +1866,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, sycl::half **c, int64_t *ldc, int64_t group_count, int64_t *group_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies = {}); + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b, float beta, float *c, @@ -1863,6 +1916,24 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i sycl::half beta, sycl::half *c, int64_t ldc, int64_t stride_c, int64_t batch_size, const std::vector &dependencies = {}); +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies = {}); + sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, float beta, float *c, int64_t ldc, @@ -1991,6 +2062,24 @@ sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex *b, int64_t ldb, const std::vector &dependencies = {}); +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, std::int64_t stridea, float *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, std::int64_t stridea, double *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies = {}); + sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/detail/backend_selector_predicates.hpp b/include/oneapi/mkl/detail/backend_selector_predicates.hpp index ce7457f18..4ee3f3bb1 100644 --- a/include/oneapi/mkl/detail/backend_selector_predicates.hpp +++ b/include/oneapi/mkl/detail/backend_selector_predicates.hpp @@ -35,7 +35,7 @@ namespace oneapi { namespace mkl { template -inline void backend_selector_precondition(sycl::queue& queue){}; +inline void backend_selector_precondition(sycl::queue&) {} template <> inline void backend_selector_precondition(sycl::queue& queue) { diff --git a/include/oneapi/mkl/detail/backends.hpp b/include/oneapi/mkl/detail/backends.hpp index 1cd9d8207..32b7c2614 100644 --- a/include/oneapi/mkl/detail/backends.hpp +++ b/include/oneapi/mkl/detail/backends.hpp @@ -36,6 +36,10 @@ enum class backend { netlib, rocblas, rocrand, + portblas, + cufft, + rocfft, + portfft, unsupported }; @@ -46,7 +50,9 @@ static backendmap backend_map = { { backend::cublas, "cublas" }, { backend::cusolver, "cusolver" }, { backend::curand, "curand" }, { backend::netlib, "netlib" }, { backend::rocblas, "rocblas" }, { backend::rocrand, "rocrand" }, - { backend::rocsolver, "rocsolver" }, { backend::unsupported, "unsupported" } + { backend::rocsolver, "rocsolver" }, { backend::portblas, "portblas" }, + { backend::cufft, "cufft" }, { backend::rocfft, "rocfft" }, + { backend::portfft, "portfft" }, { backend::unsupported, "unsupported" } }; } //namespace mkl diff --git a/include/oneapi/mkl/detail/backends_table.hpp b/include/oneapi/mkl/detail/backends_table.hpp index 6352f2a49..8e68674cc 100644 --- a/include/oneapi/mkl/detail/backends_table.hpp +++ b/include/oneapi/mkl/detail/backends_table.hpp @@ -41,7 +41,7 @@ namespace oneapi { namespace mkl { enum class device : uint16_t { x86cpu, intelgpu, nvidiagpu, amdgpu }; -enum class domain : uint16_t { blas, dft, lapack, rng }; +enum class domain : uint16_t { blas, dft, lapack, rng, sparse_blas }; static std::map>> libraries = { { domain::blas, @@ -51,25 +51,37 @@ static std::map>> libraries = LIB_NAME("blas_mklcpu"), #endif #ifdef ENABLE_NETLIB_BACKEND - LIB_NAME("blas_netlib") + LIB_NAME("blas_netlib"), +#endif +#ifdef ENABLE_PORTBLAS_BACKEND_INTEL_CPU + LIB_NAME("blas_portblas"), #endif } }, { device::intelgpu, { #ifdef ENABLE_MKLGPU_BACKEND - LIB_NAME("blas_mklgpu") + LIB_NAME("blas_mklgpu"), +#endif +#ifdef ENABLE_PORTBLAS_BACKEND_INTEL_GPU + LIB_NAME("blas_portblas"), #endif } }, { device::amdgpu, { #ifdef ENABLE_ROCBLAS_BACKEND - LIB_NAME("blas_rocblas") + LIB_NAME("blas_rocblas"), +#endif +#ifdef ENABLE_PORTBLAS_BACKEND_AMD_GPU + LIB_NAME("blas_portblas"), #endif } }, { device::nvidiagpu, { #ifdef ENABLE_CUBLAS_BACKEND - LIB_NAME("blas_cublas") + LIB_NAME("blas_cublas"), +#endif +#ifdef ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU + LIB_NAME("blas_portblas"), #endif } } } }, @@ -78,12 +90,36 @@ static std::map>> libraries = { #ifdef ENABLE_MKLCPU_BACKEND LIB_NAME("dft_mklcpu") +#endif +#ifdef ENABLE_PORTFFT_BACKEND + LIB_NAME("dft_portfft") #endif } }, { device::intelgpu, { #ifdef ENABLE_MKLGPU_BACKEND LIB_NAME("dft_mklgpu") +#endif +#ifdef ENABLE_PORTFFT_BACKEND + LIB_NAME("dft_portfft") +#endif + } }, + { device::amdgpu, + { +#ifdef ENABLE_ROCFFT_BACKEND + LIB_NAME("dft_rocfft") +#endif +#ifdef ENABLE_PORTFFT_BACKEND + LIB_NAME("dft_portfft") +#endif + } }, + { device::nvidiagpu, + { +#ifdef ENABLE_CUFFT_BACKEND + LIB_NAME("dft_cufft") +#endif +#ifdef ENABLE_PORTFFT_BACKEND + LIB_NAME("dft_portfft") #endif } } } }, @@ -137,13 +173,29 @@ static std::map>> libraries = #ifdef ENABLE_CURAND_BACKEND LIB_NAME("rng_curand") #endif - } } } } + } } } }, + + { domain::sparse_blas, + { { device::x86cpu, + { +#ifdef ENABLE_MKLCPU_BACKEND + LIB_NAME("sparse_blas_mklcpu") +#endif + } }, + { device::intelgpu, + { +#ifdef ENABLE_MKLGPU_BACKEND + LIB_NAME("sparse_blas_mklgpu") +#endif + } } } }, }; static std::map table_names = { { domain::blas, "mkl_blas_table" }, { domain::lapack, "mkl_lapack_table" }, { domain::dft, "mkl_dft_table" }, - { domain::rng, "mkl_rng_table" } }; + { domain::rng, "mkl_rng_table" }, + { domain::sparse_blas, + "mkl_sparse_blas_table" } }; } //namespace mkl } //namespace oneapi diff --git a/include/oneapi/mkl/dft/backward.hpp b/include/oneapi/mkl/dft/backward.hpp index b227248c6..3cd03e13b 100644 --- a/include/oneapi/mkl/dft/backward.hpp +++ b/include/oneapi/mkl/dft/backward.hpp @@ -26,51 +26,139 @@ #include #endif +#include "detail/types_impl.hpp" + namespace oneapi::mkl::dft { //Buffer version //In-place transform template -void compute_backward(descriptor_type &desc, sycl::buffer &inout); +void compute_backward(descriptor_type &desc, sycl::buffer &inout) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + auto type_corrected_inout = inout.template reinterpret( + detail::reinterpret_range(inout.size())); + get_commit(desc)->backward_ip_cc(desc, type_corrected_inout); +} //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im); + sycl::buffer &inout_im) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + auto type_corrected_inout_re = inout_re.template reinterpret( + detail::reinterpret_range(inout_re.size())); + auto type_corrected_inout_im = inout_im.template reinterpret( + detail::reinterpret_range(inout_im.size())); + get_commit(desc)->backward_ip_rr(desc, type_corrected_inout_re, type_corrected_inout_im); +} //Out-of-place transform template void compute_backward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out); + sycl::buffer &out) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + using bwd_type = typename detail::descriptor_info::backward_type; + auto type_corrected_in = in.template reinterpret( + detail::reinterpret_range(in.size())); + auto type_corrected_out = out.template reinterpret( + detail::reinterpret_range(out.size())); + get_commit(desc)->backward_op_cc(desc, type_corrected_in, type_corrected_out); +} //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template void compute_backward(descriptor_type &desc, sycl::buffer &in_re, sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im); + sycl::buffer &out_im) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + auto type_corrected_in_re = in_re.template reinterpret( + detail::reinterpret_range(in_re.size())); + auto type_corrected_in_im = in_im.template reinterpret( + detail::reinterpret_range(in_im.size())); + auto type_corrected_out_re = out_re.template reinterpret( + detail::reinterpret_range(out_re.size())); + auto type_corrected_out_im = out_im.template reinterpret( + detail::reinterpret_range(out_im.size())); + get_commit(desc)->backward_op_rr(desc, type_corrected_in_re, type_corrected_in_im, + type_corrected_out_re, type_corrected_out_im); +} //USM version //In-place transform template sycl::event compute_backward(descriptor_type &desc, data_type *inout, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + return get_commit(desc)->backward_ip_cc(desc, reinterpret_cast(inout), + dependencies); +} //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + return get_commit(desc)->backward_ip_rr(desc, reinterpret_cast(inout_re), + reinterpret_cast(inout_im), + dependencies); +} //Out-of-place transform template sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + using bwd_type = typename detail::descriptor_info::backward_type; + return get_commit(desc)->backward_op_cc(desc, reinterpret_cast(in), + reinterpret_cast(out), dependencies); +} //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template sycl::event compute_backward(descriptor_type &desc, input_type *in_re, input_type *in_im, output_type *out_re, output_type *out_im, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + return get_commit(desc)->backward_op_rr(desc, reinterpret_cast(in_re), + reinterpret_cast(in_im), + reinterpret_cast(out_re), + reinterpret_cast(out_im), dependencies); +} } // namespace oneapi::mkl::dft #endif // _ONEMKL_DFT_BACKWARD_HPP_ diff --git a/include/oneapi/mkl/dft/detail/commit_impl.hpp b/include/oneapi/mkl/dft/detail/commit_impl.hpp index ee4aee8dd..9e827f357 100644 --- a/include/oneapi/mkl/dft/detail/commit_impl.hpp +++ b/include/oneapi/mkl/dft/detail/commit_impl.hpp @@ -26,28 +26,46 @@ #include #endif +#include "descriptor_impl.hpp" +#include "external_workspace_helper.hpp" + namespace oneapi::mkl { enum class backend; } namespace oneapi::mkl::dft::detail { -enum class precision; -enum class domain; template class dft_values; template class commit_impl { + sycl::queue queue_; + mkl::backend backend_; + +public: + using descriptor_type = typename oneapi::mkl::dft::detail::descriptor; + using fwd_type = typename descriptor_info::forward_type; + using bwd_type = typename descriptor_info::backward_type; + using scalar_type = typename descriptor_info::scalar_type; + +protected: + external_workspace_helper external_workspace_helper_; + public: - commit_impl(sycl::queue queue, mkl::backend backend) : queue_(queue), backend_(backend) {} + commit_impl(sycl::queue queue, mkl::backend backend, + const dft::detail::dft_values &config_values) + : queue_(queue), + backend_(backend), + external_workspace_helper_(config_values.workspace_placement == + dft::detail::config_value::WORKSPACE_EXTERNAL) {} // rule of three - commit_impl(const commit_impl& other) = delete; - commit_impl& operator=(const commit_impl& other) = delete; + commit_impl(const commit_impl &other) = delete; + commit_impl &operator=(const commit_impl &other) = delete; virtual ~commit_impl() = default; - sycl::queue& get_queue() noexcept { + sycl::queue &get_queue() noexcept { return queue_; } @@ -55,13 +73,110 @@ class commit_impl { return backend_; } - virtual void* get_handle() noexcept = 0; + virtual void *get_handle() noexcept = 0; - virtual void commit(const dft_values&) = 0; + virtual void commit(const dft_values &) = 0; -private: - mkl::backend backend_; - sycl::queue queue_; + inline std::int64_t get_workspace_external_bytes() { + return external_workspace_helper_.get_rqd_workspace_bytes(*this); + } + + // set_workspace should be overridden for any backend that enables external workspaces. + // If these are overridden, get_workspace_external_bytes_impl must also be overridden. + // For backends that do not support external workspaces, these functions do not need to be overridden. + // When not overridden, external workspace support is faked: an external workspace can be set, + // and errors will be generated according to the specificiation, + // but the required workspace size will always be zero, and any given workspace will not actually be used. + virtual void set_workspace(scalar_type *usm_workspace) { + external_workspace_helper_.set_workspace_throw(*this, usm_workspace); + } + virtual void set_workspace(sycl::buffer &buffer_workspace) { + external_workspace_helper_.set_workspace_throw(*this, buffer_workspace); + } + + virtual void forward_ip_cc(descriptor_type &desc, sycl::buffer &inout) = 0; + virtual void forward_ip_rr(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) = 0; + virtual void forward_op_cc(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out) = 0; + virtual void forward_op_rr(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) = 0; + + virtual sycl::event forward_ip_cc(descriptor_type &desc, fwd_type *inout, + const std::vector &dependencies) = 0; + virtual sycl::event forward_ip_rr(descriptor_type &desc, scalar_type *inout_re, + scalar_type *inout_im, + const std::vector &dependencies) = 0; + virtual sycl::event forward_op_cc(descriptor_type &desc, fwd_type *in, bwd_type *out, + const std::vector &dependencies) = 0; + virtual sycl::event forward_op_rr(descriptor_type &desc, scalar_type *in_re, scalar_type *in_im, + scalar_type *out_re, scalar_type *out_im, + const std::vector &dependencies) = 0; + + virtual void backward_ip_cc(descriptor_type &desc, sycl::buffer &inout) = 0; + virtual void backward_ip_rr(descriptor_type &desc, sycl::buffer &inout_re, + sycl::buffer &inout_im) = 0; + virtual void backward_op_cc(descriptor_type &desc, sycl::buffer &in, + sycl::buffer &out) = 0; + virtual void backward_op_rr(descriptor_type &desc, sycl::buffer &in_re, + sycl::buffer &in_im, + sycl::buffer &out_re, + sycl::buffer &out_im) = 0; + + virtual sycl::event backward_ip_cc(descriptor_type &desc, fwd_type *inout, + const std::vector &dependencies) = 0; + virtual sycl::event backward_ip_rr(descriptor_type &desc, scalar_type *inout_re, + scalar_type *inout_im, + const std::vector &dependencies) = 0; + virtual sycl::event backward_op_cc(descriptor_type &desc, bwd_type *in, fwd_type *out, + const std::vector &dependencies) = 0; + virtual sycl::event backward_op_rr(descriptor_type &desc, scalar_type *in_re, + scalar_type *in_im, scalar_type *out_re, scalar_type *out_im, + const std::vector &dependencies) = 0; + + /** For compute calls, throw errors for the external workspace as required. + * @tparam ArgTs The non-descriptor arg(s) for the compute call. First one is used to check + * buffer or USM call. + * @param function_name The function name to user in generated exceptions. + */ + template + void compute_call_throw(const char *function_name) { + external_workspace_helper_.template compute_call_throw(function_name); + } + + /** Create an accessor out of the workspace buffer when required, to ensure correct dependency + * management for the buffer. To be used by backends that don't natively support sycl::buffers. + * @param function_name The function name to user in generated exceptions. + * @param cgh The command group handler to associate the accessor with. + */ + void add_buffer_workspace_dependency_if_rqd(const char *function_name, sycl::handler &cgh) { + external_workspace_helper_.add_buffer_dependency_if_rqd(function_name, cgh); + } + + /** If WORKSPACE_EXTERNAL is set, depend on the last USM workspace event added via set_last_usm_workspace_event. + * @param cgh The command group handler to associate the accessor with. + */ + void depend_on_last_usm_workspace_event_if_rqd(sycl::handler &cgh) { + external_workspace_helper_.depend_on_last_usm_workspace_event_if_rqd(cgh); + } + + /** If WORKSPACE_EXTERNAL is set, store the given event internally to allow it to be depended upon by + * subsequent calls to depend_on_last_usm_workspace_event. + * @param sycl_event The last usage of the USM workspace. + */ + void set_last_usm_workspace_event_if_rqd(sycl::event &sycl_event) { + external_workspace_helper_.set_last_usm_workspace_event_if_rqd(sycl_event); + } + +protected: + friend class external_workspace_helper; + + // This must be reimplemented for backends that support external workspaces. + virtual std::int64_t get_workspace_external_bytes_impl() { + return 0; + } }; } // namespace oneapi::mkl::dft::detail diff --git a/include/oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp b/include/oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp new file mode 100644 index 000000000..4e4ad2030 --- /dev/null +++ b/include/oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_CUFFT_HPP_ +#define _ONEMKL_DFT_CUFFT_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" + +namespace oneapi::mkl::dft::cufft { + +#include "oneapi/mkl/dft/detail/dft_ct.hxx" + +} // namespace oneapi::mkl::dft::cufft + +#endif // _ONEMKL_DFT_CUFFT_HPP_ diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp index d3bad05cb..a9c3f946c 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -45,6 +45,9 @@ inline commit_impl* get_commit(descriptor& desc); template class descriptor { +private: + using scalar_type = typename descriptor_info::scalar_type; + public: // Syntax for 1-dimensional DFT descriptor(std::int64_t length); @@ -52,6 +55,14 @@ class descriptor { // Syntax for d-dimensional DFT descriptor(std::vector dimensions); + // Copy operations are included in the oneAPI oneMKL specification, but not yet + // implemented here. If you need copies, please open an issue at + // https://github.com/oneapi-src/oneMKL/issues + + descriptor(descriptor&&); + + descriptor& operator=(descriptor&&); + ~descriptor(); void set_value(config_param param, ...); @@ -68,9 +79,25 @@ class descriptor { void commit(backend_selector selector); #endif +#ifdef ENABLE_CUFFT_BACKEND + void commit(backend_selector selector); +#endif + +#ifdef ENABLE_ROCFFT_BACKEND + void commit(backend_selector selector); +#endif + +#ifdef ENABLE_PORTFFT_BACKEND + void commit(backend_selector selector); +#endif + const dft_values& get_values() const noexcept { return values_; - }; + } + + void set_workspace(scalar_type* usm_workspace); + + void set_workspace(sycl::buffer& buffer_workspace); private: // Has a value when the descriptor is committed. @@ -80,6 +107,8 @@ class descriptor { dft_values values_; friend commit_impl* get_commit(descriptor&); + + using real_t = typename precision_t::real_t; }; template diff --git a/include/oneapi/mkl/dft/detail/dft_ct.hxx b/include/oneapi/mkl/dft/detail/dft_ct.hxx index 1a095427c..20cd537d8 100644 --- a/include/oneapi/mkl/dft/detail/dft_ct.hxx +++ b/include/oneapi/mkl/dft/detail/dft_ct.hxx @@ -25,96 +25,114 @@ ONEMKL_EXPORT dft::detail::commit_impl *create_commit( // BUFFER version +template +using scalar = typename detail::descriptor_info::scalar_type; +template +using fwd = typename detail::descriptor_info::forward_type; +template +using bwd = typename detail::descriptor_info::backward_type; + //In-place transform -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout); //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout_re, + sycl::buffer, 1> &inout_im); //Out-of-place transform -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out); //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &in_re, + sycl::buffer, 1> &in_im, + sycl::buffer, 1> &out_re, + sycl::buffer, 1> &out_im); //USM version //In-place transform -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *inout, + const std::vector &dependencies); //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, - data_type *inout_im, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar *inout_re, + scalar *inout_im, + const std::vector &dependencies); //Out-of-place transform -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *in, + bwd *out, + const std::vector &dependencies); //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re, - input_type *in_im, output_type *out_re, - output_type *out_im, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar *in_re, + scalar *in_im, + scalar *out_re, + scalar *out_im, + const std::vector &dependencies); // BUFFER version //In-place transform -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout); //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout_re, + sycl::buffer, 1> &inout_im); //Out-of-place transform -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out); //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in_re, + sycl::buffer, 1> &in_im, + sycl::buffer, 1> &out_re, + sycl::buffer, 1> &out_im); //USM version //In-place transform -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd *inout, + const std::vector &dependencies); //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, - data_type *inout_im, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar *inout_re, + scalar *inout_im, + const std::vector &dependencies); //Out-of-place transform -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd *in, + fwd *out, + const std::vector &dependencies); //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re, - input_type *in_im, output_type *out_re, - output_type *out_im, - const std::vector &dependencies = {}); +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar *in_re, + scalar *in_im, + scalar *out_re, + scalar *out_im, + const std::vector &dependencies); diff --git a/include/oneapi/mkl/dft/detail/external_workspace_helper.hpp b/include/oneapi/mkl/dft/detail/external_workspace_helper.hpp new file mode 100644 index 000000000..b41dffc4c --- /dev/null +++ b/include/oneapi/mkl/dft/detail/external_workspace_helper.hpp @@ -0,0 +1,194 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_EXTERNAL_WORKSPACE_HELPER_HPP_ +#define _ONEMKL_DFT_EXTERNAL_WORKSPACE_HELPER_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" +#include "oneapi/mkl/dft/detail/commit_impl.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace detail { + +template +class external_workspace_helper { +public: + using commit_impl_t = commit_impl; + using scalar_t = typename commit_impl_t::scalar_type; + +private: + // Enum to represent whatever the workspace was set as. + enum class ext_workspace_type { + not_set, + usm, + buffer, + }; + + // Is an external workspace required? + bool m_ext_workspace_rqd; + + // Set workspace type, with optional workspaces. + ext_workspace_type m_workspace_type; + + // Minimum size of workspace in bytes. -1 indicates not set. + std::int64_t m_workspace_bytes_rqd; + + // Needed for adding dependencies to the SYCL runtime for backends that don't take + // the buffer as an argument. + std::optional> m_workspace_buffer; + + // Needed for creating dependencies between forward and backward calls in some backends. + sycl::event m_usm_workspace_last_dependency; + +public: + /** Constructor. + * @param ext_workspace_rqd True if WORKSPACE_PLACEMENT is set to WORKSPACE_EXTERNAL. + */ + constexpr external_workspace_helper(bool ext_workspace_rqd) + : m_ext_workspace_rqd(ext_workspace_rqd), + m_workspace_type(ext_workspace_type::not_set), + m_workspace_bytes_rqd(-1) {} + + /** Get the required workspace bytes for the backend's external workspace. + * @param committed_desc The backend's native descriptor. + */ + std::int64_t get_rqd_workspace_bytes(commit_impl_t& committed_desc) { + if (m_workspace_bytes_rqd == -1) { + m_workspace_bytes_rqd = committed_desc.get_workspace_external_bytes_impl(); + } + return m_workspace_bytes_rqd; + } + + /** Throw according to spec for setting the workspace. USM version. + * @param committed_desc The backend's native descriptor. + * @param usm_workspace A USM allocation for the workspace. Assumed to be sufficeintly large. + */ + void set_workspace_throw(commit_impl_t& committed_desc, scalar_t* usm_workspace) { + if (get_rqd_workspace_bytes(committed_desc) > 0 && usm_workspace == nullptr) { + throw mkl::invalid_argument("DFT", "set_workspace", + "Backend expected a non-null workspace pointer."); + } + m_ext_workspace_rqd = true; + m_workspace_type = ext_workspace_type::usm; + } + + /** Throw according to spec for setting the workspace. Buffer version. + * @param committed_desc The backend's native descriptor. + * @param buffer_workspace A buffer for the workspace + */ + void set_workspace_throw(commit_impl_t& committed_desc, + sycl::buffer& buffer_workspace) { + if (static_cast(get_rqd_workspace_bytes(committed_desc)) / sizeof(scalar_t) > + buffer_workspace.size()) { + throw mkl::invalid_argument("DFT", "set_workspace", "Provided workspace is too small"); + return; + } + if (buffer_workspace.is_sub_buffer()) { + throw mkl::invalid_argument("DFT", "set_workspace", + "Cannot use sub-buffers for workspace"); + return; + } + m_ext_workspace_rqd = true; + m_workspace_type = ext_workspace_type::buffer; + m_workspace_buffer = buffer_workspace; + } + + template + void compute_call_throw(const char* function_name) const { + constexpr bool is_pointer = std::is_pointer_v>; + if constexpr (is_pointer) { + usm_compute_call_throw(function_name); + } + else { + buffer_compute_call_throw(function_name); + } + } + + void add_buffer_dependency_if_rqd(const char* function_name, sycl::handler& cgh) { + if (m_ext_workspace_rqd) { + if (m_workspace_buffer) { + if (m_workspace_buffer->size()) { + m_workspace_buffer->template get_access(cgh); + } + } + else { + throw mkl::invalid_argument( + "DFT", function_name, + "Buffer external workspace must be used with buffer compute calls"); + } + } + } + + /** If WORKSPACE_EXTERNAL is set, depend on the last USM workspace event added via set_last_usm_workspace_event. + * @param cgh The command group handler to associate the accessor with. + */ + void depend_on_last_usm_workspace_event_if_rqd(sycl::handler& cgh) { + if (m_ext_workspace_rqd) { + cgh.depends_on(m_usm_workspace_last_dependency); + } + } + + /** If WORKSPACE_EXTERNAL is set, store the given event internally to allow it to be depended upon by + * subsequent calls to depend_on_last_usm_workspace_event. + * @param sycl_event The last usage of the USM workspace. + */ + void set_last_usm_workspace_event_if_rqd(sycl::event& sycl_event) { + if (m_ext_workspace_rqd) { + m_usm_workspace_last_dependency = sycl_event; + } + } + +private: + /** When a compute function using USM arguments is called, throw an exception if an incorrect workspace has been set. + * @param function_name The name of the function to use in the error. + */ + void usm_compute_call_throw(const char* function_name) const { + if (m_ext_workspace_rqd && m_workspace_type != ext_workspace_type::usm) { + throw mkl::invalid_argument( + "DFT", function_name, "USM external workspace must be used with usm compute calls"); + } + } + + /** When a compute function using buffer arguments is called, throw an exception if an incorrect workspace has been set. + * @param function_name The name of the function to use in the error. + */ + void buffer_compute_call_throw(const char* function_name) const { + if (m_ext_workspace_rqd && m_workspace_type != ext_workspace_type::buffer) { + throw mkl::invalid_argument( + "DFT", function_name, + "Buffer external workspace must be used with buffer compute calls"); + } + } +}; + +} // namespace detail +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif //_ONEMKL_DFT_EXTERNAL_WORKSPACE_HELPER_HPP_ diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp index cfd2c6d99..00d4dd47b 100644 --- a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -29,26 +29,10 @@ #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/dft/detail/types_impl.hpp" -namespace oneapi { -namespace mkl { -namespace dft { - -namespace detail { -// Forward declarations -template -class commit_impl; - -template -class descriptor; -} // namespace detail - -namespace mklcpu { +namespace oneapi::mkl::dft::mklcpu { #include "oneapi/mkl/dft/detail/dft_ct.hxx" -} // namespace mklcpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklcpu #endif // _ONEMKL_DFT_MKLCPU_HPP_ diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index a03e235c7..56a55a9f7 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -29,26 +29,10 @@ #include "oneapi/mkl/detail/export.hpp" #include "oneapi/mkl/dft/detail/types_impl.hpp" -namespace oneapi { -namespace mkl { -namespace dft { - -namespace detail { -// Forward declarations -template -class commit_impl; - -template -class descriptor; -} // namespace detail - -namespace mklgpu { +namespace oneapi::mkl::dft::mklgpu { #include "oneapi/mkl/dft/detail/dft_ct.hxx" -} // namespace mklgpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklgpu #endif // _ONEMKL_DFT_MKLGPU_HPP_ diff --git a/include/oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp b/include/oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp new file mode 100644 index 000000000..4617e8a5c --- /dev/null +++ b/include/oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_PORTFFT_HPP_ +#define _ONEMKL_DFT_PORTFFT_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" + +namespace oneapi::mkl::dft::portfft { + +// We don't need the forward declarations of compute_xxxward templates (just need the create_commit template), but it doesn't hurt and keeps things simple. +#include "oneapi/mkl/dft/detail/dft_ct.hxx" + +} // namespace oneapi::mkl::dft::portfft + +#endif // _ONEMKL_DFT_PORTFFT_HPP_ diff --git a/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp b/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp new file mode 100644 index 000000000..fe3305680 --- /dev/null +++ b/include/oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp @@ -0,0 +1,38 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_ROCFFT_HPP_ +#define _ONEMKL_DFT_ROCFFT_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" + +namespace oneapi::mkl::dft::rocfft { + +#include "oneapi/mkl/dft/detail/dft_ct.hxx" + +} // namespace oneapi::mkl::dft::rocfft + +#endif // _ONEMKL_DFT_ROCFFT_HPP_ diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index 07a24a39a..60eb922ab 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -20,21 +20,117 @@ #ifndef _ONEMKL_DETAIL_TYPES_IMPL_HPP_ #define _ONEMKL_DETAIL_TYPES_IMPL_HPP_ +#if __has_include() +#include +#else +#include +#endif + #include #include #include +#include namespace oneapi { namespace mkl { namespace dft { namespace detail { -typedef int DFT_ERROR; - -#define DFT_NOTSET -1 +typedef long DFT_ERROR; enum class precision { SINGLE, DOUBLE }; + +template +struct precision_t { + using real_t = std::conditional_t; +}; + enum class domain { REAL, COMPLEX }; + +// Forward declarations +template +class commit_impl; + +template +class descriptor; + +template +constexpr bool always_false = false; + +template +struct descriptor_info { + static_assert(always_false, "Not a valid descriptor type"); +}; + +template <> +struct descriptor_info> { + using scalar_type = float; + using forward_type = float; + using backward_type = std::complex; +}; +template <> +struct descriptor_info> { + using scalar_type = float; + using forward_type = std::complex; + using backward_type = std::complex; +}; +template <> +struct descriptor_info> { + using scalar_type = double; + using forward_type = double; + using backward_type = std::complex; +}; +template <> +struct descriptor_info> { + using scalar_type = double; + using forward_type = std::complex; + using backward_type = std::complex; +}; + +// Get the scalar type associated with a descriptor. +template +using descriptor_scalar_t = typename descriptor_info::scalar_type; + +template +constexpr bool is_complex_dft = false; +template +constexpr bool is_complex_dft> = true; + +template +constexpr bool is_complex = false; +template +constexpr bool is_complex> = true; + +template +using is_one_of = typename std::bool_constant<(std::is_same_v || ...)>; + +template +using valid_compute_arg = typename std::bool_constant< + (std::is_same_v, float> && + is_one_of>::value) || + (std::is_same_v, double> && + is_one_of>::value)>; + +template +constexpr bool valid_ip_realreal_impl = + is_complex_dft&& std::is_same_v, data_t>; + +// compute the range of a reinterpreted buffer +template +std::size_t reinterpret_range(std::size_t size) { + if constexpr (sizeof(In) >= sizeof(Out)) { + static_assert(sizeof(In) % sizeof(Out) == 0); + return size * (sizeof(In) / sizeof(Out)); + } + else { + static_assert(sizeof(Out) % sizeof(In) == 0); + if (size % (sizeof(Out) / sizeof(In))) { + throw std::runtime_error("buffer cannot be evenly divived into the expected type"); + } + return size / (sizeof(Out) / sizeof(In)); + } +} + enum class config_param { FORWARD_DOMAIN, DIMENSION, @@ -52,17 +148,22 @@ enum class config_param { PLACEMENT, - INPUT_STRIDES, - OUTPUT_STRIDES, + INPUT_STRIDES [[deprecated("Use FWD/BWD_STRIDES")]], + OUTPUT_STRIDES [[deprecated("Use FWD/BWD_STRIDES")]], FWD_DISTANCE, BWD_DISTANCE, WORKSPACE, + WORKSPACE_PLACEMENT, + WORKSPACE_EXTERNAL_BYTES, ORDERING, TRANSPOSE, PACKED_FORMAT, - COMMIT_STATUS + COMMIT_STATUS, + + FWD_STRIDES, + BWD_STRIDES }; enum class config_value { @@ -91,17 +192,23 @@ enum class config_value { NONE, // for config_param::PACKED_FORMAT for storing conjugate-even finite sequence in real containers - CCE_FORMAT + CCE_FORMAT, + + // For config_param::WORKSPACE_PLACEMENT + WORKSPACE_AUTOMATIC, + WORKSPACE_EXTERNAL }; template class dft_values { private: - using real_t = std::conditional_t; + using real_t = typename precision_t::real_t; public: std::vector input_strides; std::vector output_strides; + std::vector fwd_strides; + std::vector bwd_strides; real_t bwd_scale; real_t fwd_scale; std::int64_t number_of_transforms; @@ -112,6 +219,7 @@ class dft_values { config_value real_storage; config_value conj_even_storage; config_value workspace; + config_value workspace_placement; config_value ordering; bool transpose; config_value packed_format; diff --git a/include/oneapi/mkl/dft/forward.hpp b/include/oneapi/mkl/dft/forward.hpp index 42e6c902c..e43c39ce0 100644 --- a/include/oneapi/mkl/dft/forward.hpp +++ b/include/oneapi/mkl/dft/forward.hpp @@ -26,52 +26,136 @@ #include #endif +#include "detail/types_impl.hpp" + namespace oneapi::mkl::dft { //Buffer version //In-place transform template -void compute_forward(descriptor_type &desc, sycl::buffer &inout); +void compute_forward(descriptor_type &desc, sycl::buffer &inout) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + auto type_corrected_inout = inout.template reinterpret( + detail::reinterpret_range(inout.size())); + get_commit(desc)->forward_ip_cc(desc, type_corrected_inout); +} //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im); + sycl::buffer &inout_im) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + auto type_corrected_inout_re = inout_re.template reinterpret( + detail::reinterpret_range(inout_re.size())); + auto type_corrected_inout_im = inout_im.template reinterpret( + detail::reinterpret_range(inout_im.size())); + get_commit(desc)->forward_ip_rr(desc, type_corrected_inout_re, type_corrected_inout_im); +} //Out-of-place transform template void compute_forward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out); + sycl::buffer &out) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + using bwd_type = typename detail::descriptor_info::backward_type; + auto type_corrected_in = in.template reinterpret( + detail::reinterpret_range(in.size())); + auto type_corrected_out = out.template reinterpret( + detail::reinterpret_range(out.size())); + get_commit(desc)->forward_op_cc(desc, type_corrected_in, type_corrected_out); +} //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template void compute_forward(descriptor_type &desc, sycl::buffer &in_re, sycl::buffer &in_im, sycl::buffer &out_re, - sycl::buffer &out_im); + sycl::buffer &out_im) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + auto type_corrected_in_re = in_re.template reinterpret( + detail::reinterpret_range(in_re.size())); + auto type_corrected_in_im = in_im.template reinterpret( + detail::reinterpret_range(in_im.size())); + auto type_corrected_out_re = out_re.template reinterpret( + detail::reinterpret_range(out_re.size())); + auto type_corrected_out_im = out_im.template reinterpret( + detail::reinterpret_range(out_im.size())); + get_commit(desc)->forward_op_rr(desc, type_corrected_in_re, type_corrected_in_im, + type_corrected_out_re, type_corrected_out_im); +} //USM version //In-place transform template sycl::event compute_forward(descriptor_type &desc, data_type *inout, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + + using fwd_type = typename detail::descriptor_info::forward_type; + return get_commit(desc)->forward_ip_cc(desc, reinterpret_cast(inout), dependencies); +} //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template +template , bool> = true> sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, data_type *inout_im, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for data_type"); + using scalar_type = typename detail::descriptor_info::scalar_type; + return get_commit(desc)->forward_ip_rr(desc, reinterpret_cast(inout_re), + reinterpret_cast(inout_im), dependencies); +} //Out-of-place transform template sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + using fwd_type = typename detail::descriptor_info::forward_type; + using bwd_type = typename detail::descriptor_info::backward_type; + return get_commit(desc)->forward_op_cc(desc, reinterpret_cast(in), + reinterpret_cast(out), dependencies); +} //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format template sycl::event compute_forward(descriptor_type &desc, input_type *in_re, input_type *in_im, output_type *out_re, output_type *out_im, - const std::vector &dependencies = {}); + const std::vector &dependencies = {}) { + static_assert(detail::valid_compute_arg::value, + "unexpected type for input_type"); + static_assert(detail::valid_compute_arg::value, + "unexpected type for output_type"); + + using scalar_type = typename detail::descriptor_info::scalar_type; + return get_commit(desc)->forward_op_rr(desc, reinterpret_cast(in_re), + reinterpret_cast(in_im), + reinterpret_cast(out_re), + reinterpret_cast(out_im), dependencies); +} } // namespace oneapi::mkl::dft #endif // _ONEMKL_DFT_FORWARD_HPP_ diff --git a/include/oneapi/mkl/exceptions.hpp b/include/oneapi/mkl/exceptions.hpp index c2db1d521..244c8c61d 100644 --- a/include/oneapi/mkl/exceptions.hpp +++ b/include/oneapi/mkl/exceptions.hpp @@ -47,7 +47,7 @@ class exception : public std::exception { : ""); } - const char *what() const noexcept { + const char *what() const noexcept override { return msg_.c_str(); } }; diff --git a/include/oneapi/mkl/lapack/exceptions.hpp b/include/oneapi/mkl/lapack/exceptions.hpp index 6460d6bff..da205cc1a 100644 --- a/include/oneapi/mkl/lapack/exceptions.hpp +++ b/include/oneapi/mkl/lapack/exceptions.hpp @@ -26,9 +26,9 @@ namespace lapack { class exception { public: exception(oneapi::mkl::exception *_ex, std::int64_t info, std::int64_t detail = 0) - : _ex(_ex), - _info(info), - _detail(detail) {} + : _info(info), + _detail(detail), + _ex(_ex) {} std::int64_t info() const { return _info; } diff --git a/include/oneapi/mkl/rng/device.hpp b/include/oneapi/mkl/rng/device.hpp new file mode 100644 index 000000000..a628395d2 --- /dev/null +++ b/include/oneapi/mkl/rng/device.hpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_SYCL_DEVICE_HPP__ +#define _MKL_RNG_SYCL_DEVICE_HPP__ + +#include "oneapi/mkl/rng/device/types.hpp" +#include "oneapi/mkl/rng/device/functions.hpp" +#include "oneapi/mkl/rng/device/distributions.hpp" +#include "oneapi/mkl/rng/device/engines.hpp" + +#endif // _MKL_RNG_SYCL_DEVICE_HPP__ diff --git a/include/oneapi/mkl/rng/device/detail/bernoulli_impl.hpp b/include/oneapi/mkl/rng/device/detail/bernoulli_impl.hpp new file mode 100644 index 000000000..83bb92f2d --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/bernoulli_impl.hpp @@ -0,0 +1,89 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_BERNOULLI_IMPL_HPP_ +#define _MKL_RNG_DEVICE_BERNOULLI_IMPL_HPP_ + +namespace oneapi::mkl::rng::device::detail { + +template +class distribution_base> { +public: + struct param_type { + param_type(float p) : p_(p) {} + float p_; + }; + + distribution_base(float p) : p_(p) { +#ifndef __SYCL_DEVICE_ONLY__ + if ((p > 1.0f) || (p < 0.0f)) { + throw oneapi::mkl::invalid_argument("rng", "bernoulli", "p < 0 || p > 1"); + } +#endif + } + + float p() const { + return p_; + } + + param_type param() const { + return param_type(p_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if ((pt.p_ > 1.0f) || (pt.p_ < 0.0f)) { + throw oneapi::mkl::invalid_argument("rng", "bernoulli", "p < 0 || p > 1"); + } +#endif + p_ = pt.p_; + } + +protected: + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + auto uni_res = engine.generate(0.0f, 1.0f); + if constexpr (EngineType::vec_size == 1) { + return IntType{ uni_res < p_ }; + } + else { + sycl::vec vec_out(IntType{ 0 }); + for (int i = 0; i < EngineType::vec_size; ++i) { + if (uni_res[i] < p_) { + vec_out[i] = IntType{ 1 }; + } + } + return vec_out; + } + } + + template + IntType generate_single(EngineType& engine) { + auto uni_res = engine.generate_single(0.0f, 1.0f); + return IntType{ uni_res < p_ }; + } + + float p_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_BERNOULLI_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/bits_impl.hpp b/include/oneapi/mkl/rng/device/detail/bits_impl.hpp new file mode 100644 index 000000000..aa68956d6 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/bits_impl.hpp @@ -0,0 +1,71 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_BITS_IMPL_HPP_ +#define _MKL_RNG_DEVICE_BITS_IMPL_HPP_ + +#include "engine_base.hpp" + +namespace oneapi::mkl::rng::device::detail { + +template +class distribution_base> { +protected: + template + auto generate(EngineType& engine) -> typename std::enable_if< + !std::is_same>::value, + typename std::conditional>::type>::type { + static_assert(std::is_same::value, + "oneMKL: bits works only with std::uint32_t"); + return engine.generate(); + } + + template + auto generate(EngineType& engine) -> typename std::enable_if< + std::is_same>::value, + typename std::conditional>::type>::type { + static_assert(std::is_same::value, + "oneMKL: bits for mcg59 works only with std::uint64_t"); + return engine.generate_bits(); + } + + template + typename std::enable_if>::value, + UIntType>::type + generate_single(EngineType& engine) { + static_assert(std::is_same::value, + "oneMKL: bits works only with std::uint32_t"); + return engine.generate_single(); + } + + template + typename std::enable_if>::value, + UIntType>::type + generate_single(EngineType& engine) { + static_assert(std::is_same::value, + "oneMKL: bits for mcg59 works only with std::uint64_t"); + return engine.generate_single(); + } +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_BITS_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/distribution_base.hpp b/include/oneapi/mkl/rng/device/detail/distribution_base.hpp new file mode 100644 index 000000000..e728a564c --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/distribution_base.hpp @@ -0,0 +1,73 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DISTRIBUTION_BASE_HPP_ +#define _MKL_RNG_DISTRIBUTION_BASE_HPP_ + +#include + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/rng/device/types.hpp" + +namespace oneapi::mkl::rng::device { + +namespace detail { + +template +class distribution_base {}; + +} // namespace detail + +// declarations of distribution classes +template +class uniform; + +template +class gaussian; + +template +class lognormal; + +template +class uniform_bits; + +template +class bits; + +template +class exponential; + +template +class poisson; + +template +class bernoulli; + +} // namespace oneapi::mkl::rng::device + +#include "oneapi/mkl/rng/device/detail/uniform_impl.hpp" +#include "oneapi/mkl/rng/device/detail/gaussian_impl.hpp" +#include "oneapi/mkl/rng/device/detail/lognormal_impl.hpp" +#include "oneapi/mkl/rng/device/detail/bits_impl.hpp" +#include "oneapi/mkl/rng/device/detail/uniform_bits_impl.hpp" +#include "oneapi/mkl/rng/device/detail/exponential_impl.hpp" +#include "oneapi/mkl/rng/device/detail/poisson_impl.hpp" +#include "oneapi/mkl/rng/device/detail/bernoulli_impl.hpp" + +#endif // _MKL_RNG_DISTRIBUTION_BASE_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/engine_base.hpp b/include/oneapi/mkl/rng/device/detail/engine_base.hpp new file mode 100644 index 000000000..fc1aee16a --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/engine_base.hpp @@ -0,0 +1,43 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_ENGINE_BASE_HPP_ +#define _MKL_RNG_DEVICE_ENGINE_BASE_HPP_ + +#include + +#include + +namespace oneapi::mkl::rng::device::detail { + +// internal structure to specify state of engine +template +struct engine_state {}; + +template +class engine_base {}; + +} // namespace oneapi::mkl::rng::device::detail + +#include "oneapi/mkl/rng/device/detail/philox4x32x10_impl.hpp" +#include "oneapi/mkl/rng/device/detail/mrg32k3a_impl.hpp" +#include "oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp" +#include "oneapi/mkl/rng/device/detail/mcg59_impl.hpp" + +#endif // _MKL_RNG_DEVICE_ENGINE_BASE_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp new file mode 100644 index 000000000..cf712f0e5 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/exponential_impl.hpp @@ -0,0 +1,112 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_EXPONENTIAL_IMPL_HPP_ +#define _MKL_RNG_DEVICE_EXPONENTIAL_IMPL_HPP_ + +#include "vm_wrappers.hpp" + +namespace oneapi::mkl::rng::device::detail { + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType a, RealType beta) : a_(a), beta_(beta) {} + RealType a_; + RealType beta_; + }; + + distribution_base(RealType a, RealType beta) : a_(a), beta_(beta) { +#ifndef __SYCL_DEVICE_ONLY__ + if (beta <= static_cast(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "exponential", "beta <= 0"); + } +#endif + } + + RealType a() const { + return a_; + } + + RealType beta() const { + return beta_; + } + + param_type param() const { + return param_type(a_, beta_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.beta_ <= static_cast(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "exponential", "beta <= 0"); + } +#endif + a_ = pt.a_; + beta_ = pt.beta_; + } + +protected: + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + using OutType = typename std::conditional>::type; + + OutType res = engine.generate(RealType(0), RealType(1)); + if constexpr (EngineType::vec_size == 1) { + res = ln_wrapper(res); + } + else { + for (int i = 0; i < EngineType::vec_size; ++i) { + res[i] = ln_wrapper(res[i]); + } + } + res = a_ - res * beta_; + if constexpr (std::is_same::value) { + res = sycl::fmax(res, OutType{ a_ }); + } + return res; + } + + template + RealType generate_single(EngineType& engine) { + RealType res = engine.generate_single(RealType(0), RealType(1)); + res = ln_wrapper(res); + res = a_ - res * beta_; + if constexpr (std::is_same::value) { + res = sycl::fmax(res, a_); + } + return res; + } + + RealType a_; + RealType beta_; + + friend class distribution_base< + oneapi::mkl::rng::device::poisson>; + friend class distribution_base< + oneapi::mkl::rng::device::poisson>; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_EXPONENTIAL_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/gaussian_impl.hpp b/include/oneapi/mkl/rng/device/detail/gaussian_impl.hpp new file mode 100644 index 000000000..4588aea97 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/gaussian_impl.hpp @@ -0,0 +1,270 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_GAUSSIAN_IMPL_HPP_ +#define _MKL_RNG_DEVICE_GAUSSIAN_IMPL_HPP_ + +#include "vm_wrappers.hpp" + +namespace oneapi::mkl::rng::device::detail { + +// sqrt(2) +template +constexpr inline RealType sqrt2() { + return 0x1.6A09E6P+0f; // 1.414213562 +} + +template <> +constexpr inline double sqrt2() { + return 0x1.6A09E667F3BCDP+0; // 1.414213562 +} + +template +class distribution_base< + oneapi::mkl::rng::device::gaussian> { +public: + struct param_type { + param_type(RealType mean, RealType stddev) : mean_(mean), stddev_(stddev) {} + RealType mean_; + RealType stddev_; + }; + + distribution_base(RealType mean, RealType stddev) : mean_(mean), stddev_(stddev) { + flag_ = false; +#ifndef __SYCL_DEVICE_ONLY__ + if (stddev <= RealType(0)) { + throw oneapi::mkl::invalid_argument("rng", "gaussian", "stddev <= 0"); + } +#endif + } + + RealType mean() const { + return mean_; + } + + RealType stddev() const { + return stddev_; + } + + param_type param() const { + return param_type(mean_, stddev_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.stddev_ <= RealType(0)) { + throw oneapi::mkl::invalid_argument("rng", "gaussian", "stddev <= 0"); + } +#endif + mean_ = pt.mean_; + stddev_ = pt.stddev_; + } + +protected: + template + __attribute__((always_inline)) inline auto generate(EngineType& engine) -> + typename std::conditional>::type { + RealType u1, u2, u1_transformed; + + if constexpr (EngineType::vec_size == 1) { + RealType res; + if (!flag_) { + u1 = engine.generate(RealType(0), RealType(1)); + u2 = engine.generate(RealType(0), RealType(1)); + u1_transformed = ln_wrapper(u1); + u1_transformed = sqrt_wrapper(static_cast(-2.0) * u1_transformed); + res = u1_transformed * sinpi_wrapper(RealType(2) * u2) * stddev_ + mean_; + u1_transformed_ = u1_transformed; + u2_ = u2; + flag_ = true; + return res; + } + res = u1_transformed_ * cospi_wrapper(RealType(2) * u2_) * stddev_ + mean_; + flag_ = false; + return res; + } + else { + RealType sin, cos; + sycl::vec res; + if (!flag_) { + constexpr std::int32_t tail = EngineType::vec_size % 2; + auto uniform_res = engine.generate(RealType(0), RealType(1)); +#pragma unroll + for (std::int32_t i = 0; i < EngineType::vec_size - tail; i += 2) { + u1 = uniform_res[i]; + u2 = uniform_res[i + 1]; + u1_transformed = ln_wrapper(u1); + u1_transformed = sqrt_wrapper(static_cast(-2.0) * u1_transformed); + sin = sincospi_wrapper(RealType(2.0) * u2, cos); + res[i] = (u1_transformed * sin) * stddev_ + mean_; + res[i + 1] = (u1_transformed * cos) * stddev_ + mean_; + } + if constexpr (tail) { + u1 = uniform_res[EngineType::vec_size - 1]; + u2 = engine.generate_single(RealType(0), RealType(1)); + u1_transformed = ln_wrapper(u1); + u1_transformed = sqrt_wrapper(static_cast(-2.0) * u1_transformed); + res[EngineType::vec_size - 1] = + u1_transformed * sinpi_wrapper(RealType(2) * u2) * stddev_ + mean_; + u1_transformed_ = u1_transformed; + u2_ = u2; + flag_ = true; + } + return res; + } + + res[0] = u1_transformed_ * cospi_wrapper(RealType(2) * u2_) * stddev_ + mean_; + flag_ = false; + constexpr std::int32_t tail = (EngineType::vec_size - 1) % 2; +#pragma unroll + for (std::int32_t i = 1; i < EngineType::vec_size - tail; i += 2) { + u1 = engine.generate_single(RealType(0), RealType(1)); + u2 = engine.generate_single(RealType(0), RealType(1)); + u1_transformed = ln_wrapper(u1); + u1_transformed = sqrt_wrapper(static_cast(-2.0) * u1_transformed); + sin = sincospi_wrapper(RealType(2.0) * u2, cos); + res[i] = (u1_transformed * sin) * stddev_ + mean_; + res[i + 1] = (u1_transformed * cos) * stddev_ + mean_; + } + if constexpr (tail) { + u1 = engine.generate_single(RealType(0), RealType(1)); + u2 = engine.generate_single(RealType(0), RealType(1)); + u1_transformed = ln_wrapper(u1); + u1_transformed = sqrt_wrapper(static_cast(-2.0) * u1_transformed); + res[EngineType::vec_size - 1] = + u1_transformed * sinpi_wrapper(RealType(2) * u2) * stddev_ + mean_; + u1_transformed_ = u1_transformed; + u2_ = u2; + flag_ = true; + } + return res; + } + } + + template + __attribute__((always_inline)) inline RealType generate_single(EngineType& engine) { + RealType u1, u2, u1_transformed; + RealType res; + if (!flag_) { + u1 = engine.generate_single(RealType(0), RealType(1)); + u2 = engine.generate_single(RealType(0), RealType(1)); + u1_transformed = ln_wrapper(u1); + u1_transformed = sqrt_wrapper(static_cast(-2.0) * u1_transformed); + res = u1_transformed * sinpi_wrapper(RealType(2) * u2) * stddev_ + mean_; + u1_transformed_ = u1_transformed; + u2_ = u2; + flag_ = true; + return res; + } + res = u1_transformed_ * cospi_wrapper(RealType(2) * u2_) * stddev_ + mean_; + flag_ = false; + return res; + } + + RealType mean_; + RealType stddev_; + bool flag_ = false; + RealType u1_transformed_; + RealType u2_; + + friend class distribution_base< + oneapi::mkl::rng::device::lognormal>; + friend class distribution_base< + oneapi::mkl::rng::device::poisson>; + friend class distribution_base< + oneapi::mkl::rng::device::poisson>; +}; + +#if MKL_RNG_USE_BINARY_CODE + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType mean, RealType stddev) : mean_(mean), stddev_(stddev) {} + RealType mean_; + RealType stddev_; + }; + + distribution_base(RealType mean, RealType stddev) : mean_(mean), stddev_(stddev) { +#ifndef __SYCL_DEVICE_ONLY__ + if (stddev <= RealType(0)) { + throw oneapi::mkl::invalid_argument("rng", "gaussian", "stddev <= 0"); + } +#endif + } + + RealType mean() const { + return mean_; + } + + RealType stddev() const { + return stddev_; + } + + param_type param() const { + return param_type(mean_, stddev_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.stddev_ <= RealType(0)) { + throw oneapi::mkl::invalid_argument("rng", "gaussian", "stddev <= 0"); + } +#endif + mean_ = pt.mean_; + stddev_ = pt.stddev_; + } + +protected: + template + __attribute__((always_inline)) inline auto generate(EngineType& engine) -> + typename std::conditional>::type { + if constexpr (EngineType::vec_size == 1) { + return generate_single(engine); + } + else { + RealType stddev = stddev_ * sqrt2(); + sycl::vec res; + sycl::vec u = + engine.generate(RealType(-1), RealType(1)); + for (std::int32_t i = 0; i < EngineType::vec_size; i++) { + res[i] = erf_inv_wrapper(u[i]); + } + return res * stddev + mean_; + } + } + + template + __attribute__((always_inline)) inline RealType generate_single(EngineType& engine) { + RealType stddev = stddev_ * sqrt2(); + RealType u = engine.generate_single(RealType(-1), RealType(1)); + return sycl::fma(erf_inv_wrapper(u), stddev, mean_); + } + + RealType mean_; + RealType stddev_; +}; +#endif + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_GAUSSIAN_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/lognormal_impl.hpp b/include/oneapi/mkl/rng/device/detail/lognormal_impl.hpp new file mode 100644 index 000000000..85e8b6d57 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/lognormal_impl.hpp @@ -0,0 +1,105 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_LOGNORMAL_IMPL_HPP_ +#define _MKL_RNG_DEVICE_LOGNORMAL_IMPL_HPP_ + +namespace oneapi::mkl::rng::device::detail { + +template +class distribution_base> { +public: + struct param_type { + param_type(RealType m, RealType s, RealType displ, RealType scale) + : m_(m), + s_(s), + displ_(displ), + scale_(scale) {} + RealType m_; + RealType s_; + RealType displ_; + RealType scale_; + }; + + distribution_base(RealType m, RealType s, RealType displ, RealType scale) + : gaussian_(m, s), + displ_(displ), + scale_(scale) { +#ifndef __SYCL_DEVICE_ONLY__ + if (scale <= static_cast(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "lognormal", "scale <= 0"); + } +#endif + } + + RealType m() const { + return gaussian_.mean(); + } + + RealType s() const { + return gaussian_.stddev(); + } + + RealType displ() const { + return displ_; + } + + RealType scale() const { + return scale_; + } + + param_type param() const { + return param_type(gaussian_.mean(), gaussian_.stddev(), displ_, scale_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.scale_ <= static_cast(0.0)) { + throw oneapi::mkl::invalid_argument("rng", "lognormal", "scale <= 0"); + } +#endif + gaussian_.param({ pt.m_, pt.s_ }); + displ_ = pt.displ_; + scale_ = pt.scale_; + } + +protected: + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + auto res = gaussian_.generate(engine); + return sycl::exp(res) * scale_ + displ_; + } + + template + RealType generate_single(EngineType& engine) { + RealType res = gaussian_.generate_single(engine); + return sycl::exp(res) * scale_ + displ_; + } + + distribution_base> + gaussian_; + RealType displ_; + RealType scale_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_LOGNORMAL_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp new file mode 100644 index 000000000..8f1294ac2 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/mcg31m1_impl.hpp @@ -0,0 +1,233 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_MCG31M1_IMPL_HPP_ +#define _MKL_RNG_DEVICE_MCG31M1_IMPL_HPP_ + +namespace oneapi::mkl::rng::device { + +template +class mcg31m1; + +namespace detail { + +template +constexpr sycl::vec select_vector_a_mcg31m1() { + if constexpr (VecSize == 1) + return sycl::vec(UINT64_C(1)); + else if constexpr (VecSize == 2) + return sycl::vec({ UINT64_C(1), UINT64_C(1132489760) }); + else if constexpr (VecSize == 3) + return sycl::vec( + { UINT64_C(1), UINT64_C(1132489760), UINT64_C(826537482) }); + else if constexpr (VecSize == 4) + return sycl::vec( + { UINT64_C(1), UINT64_C(1132489760), UINT64_C(826537482), UINT64_C(289798557) }); + else if constexpr (VecSize == 8) + return sycl::vec({ UINT64_C(1), UINT64_C(1132489760), UINT64_C(826537482), + UINT64_C(289798557), UINT64_C(480863449), + UINT64_C(1381340036), UINT64_C(1582925527), + UINT64_C(1918178478) }); + else + return sycl::vec( + { UINT64_C(1), UINT64_C(1132489760), UINT64_C(826537482), UINT64_C(289798557), + UINT64_C(480863449), UINT64_C(1381340036), UINT64_C(1582925527), UINT64_C(1918178478), + UINT64_C(1286028348), UINT64_C(482167044), UINT64_C(262060616), UINT64_C(1856662125), + UINT64_C(839877947), UINT64_C(1997268203), UINT64_C(458714024), + UINT64_C(650347998) }); +} + +// hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor +// that's why in case of hipSYCL backend sycl::vec is created as a local variable +#ifndef __HIPSYCL__ +template +struct mcg31m1_vector_a { + static constexpr sycl::vec vector_a = + select_vector_a_mcg31m1(); // powers of a +}; +#endif + +struct mcg31m1_param { + static constexpr std::uint32_t a = 1132489760; + static constexpr std::uint64_t m_64 = 0x000000007FFFFFFF; // 2^31 - 1 + static constexpr double m_fl = 2147483647.0; // 2^31 - 1 + static constexpr std::uint64_t bits = 31; +}; + +template +struct engine_state> { + std::uint32_t s; +}; + +namespace mcg31m1_impl { + +// Improved modulus x % (2^31 - 1) operation (possible to do for divisor (2^N +// -1), but MCG31M1 needs only 2^31 - 1) if we want to do x % (2^N -1) we can +// find out that: x = A + B * 2^N, where A = x % 2^N = x & 00..01..11 (binary) +// where quantity of 1 is N, B = x / 2^N = x >> N also x = A + B * (2^N - 1 + 1) +// = (A + B) + B * (2^N - 1), but (A + B) may be greater than (2^N - 1), that's +// why we put x1 = A + B = A' + B' * 2^N = ... until new (A + B) < (2^N - 1) for +// MCG31m1 N = 31 +template +static inline T custom_mod(std::uint64_t x) { + std::uint64_t b = x >> mcg31m1_param::bits; + std::uint64_t a = x & mcg31m1_param::m_64; + x = a + b; + b = x >> mcg31m1_param::bits; + a = x & mcg31m1_param::m_64; + return static_cast(a + b); +} + +template +static inline sycl::vec custom_mod( + const sycl::vec& x) { + sycl::vec b = x >> mcg31m1_param::bits; + sycl::vec a = x & mcg31m1_param::m_64; + sycl::vec res = a + b; + b = res >> mcg31m1_param::bits; + a = res & mcg31m1_param::m_64; + res = a + b; + return res.template convert(); +} + +static inline std::uint64_t power(std::uint64_t a, std::uint64_t n) { + std::uint64_t a2; + // initialize result by 1 for recurrence + std::uint32_t result = 1; + + if (n == 0) { + // return (a^0)%m = 1 + return std::uint64_t{ 1 }; + } + + // Recurrence loop + do { + // For each odd n + if (n & 1) { + a2 = static_cast(result) * a; + result = custom_mod(a2); + } + // n /= 2 + n >>= 1; + + a2 = a * a; + a = custom_mod(a2); + } while (n); + + return static_cast(result); +} + +template +static inline void skip_ahead(engine_state>& state, + std::uint64_t num_to_skip) { + std::uint64_t loc_A = power(static_cast(mcg31m1_param::a), num_to_skip); + state.s = custom_mod(loc_A * static_cast(state.s)); +} + +template +static inline void init(engine_state>& state, + std::uint32_t seed, std::uint64_t offset) { + state.s = custom_mod(seed); + if (state.s == 0) + state.s = 1; + skip_ahead(state, offset); +} + +template +static inline sycl::vec generate( + engine_state>& state) { + sycl::vec x(state.s); + sycl::vec res; +#ifndef __HIPSYCL__ + res = custom_mod(mcg31m1_vector_a::vector_a * x); +#else + // a workaround for hipSYCL (AdaptiveCpp) + res = custom_mod(select_vector_a_mcg31m1() * x); +#endif + state.s = + custom_mod(mcg31m1_param::a * static_cast(res[VecSize - 1])); + return res; +} + +template +static inline std::uint32_t generate_single( + engine_state>& state) { + std::uint32_t x = state.s; + state.s = custom_mod(mcg31m1_param::a * static_cast(state.s)); + return x; +} + +} // namespace mcg31m1_impl + +template +class engine_base> { +protected: + engine_base(std::uint32_t seed, std::uint64_t offset = 0) { + mcg31m1_impl::init(this->state_, seed, offset); + } + + template + auto generate(RealType a, RealType b) -> + typename std::conditional>::type { + sycl::vec res; + sycl::vec res_uint; + + RealType c = (b - a) / static_cast(mcg31m1_param::m_fl); + + res_uint = mcg31m1_impl::generate(this->state_); + + res = res_uint.template convert() * c + a; + + return res; + } + + auto generate() -> typename std::conditional>::type { + return mcg31m1_impl::generate(this->state_); + } + + template + RealType generate_single(RealType a, RealType b) { + RealType res; + std::uint32_t res_uint; + + RealType c = (b - a) / static_cast(mcg31m1_param::m_fl); + + res_uint = mcg31m1_impl::generate_single(this->state_); + + res = static_cast(res_uint) * c + a; + return res; + } + + std::uint32_t generate_single() { + return mcg31m1_impl::generate_single(this->state_); + } + + void skip_ahead(std::uint64_t num_to_skip) { + detail::mcg31m1_impl::skip_ahead(this->state_, num_to_skip); + } + + engine_state> state_; +}; + +} // namespace detail + +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_MCG31M1_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp new file mode 100644 index 000000000..bc21eb607 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/mcg59_impl.hpp @@ -0,0 +1,275 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_MCG59_IMPL_HPP_ +#define _MKL_RNG_DEVICE_MCG59_IMPL_HPP_ + +namespace oneapi::mkl::rng::device { + +template +class mcg59; + +namespace detail { + +template +constexpr sycl::vec select_vector_a_mcg59() { + if constexpr (VecSize == 1) + return sycl::vec(UINT64_C(1)); + else if constexpr (VecSize == 2) + return sycl::vec({ UINT64_C(1), UINT64_C(0x113769B23C5FD) }); + else if constexpr (VecSize == 3) + return sycl::vec( + { UINT64_C(1), UINT64_C(0x113769B23C5FD), UINT64_C(0x65C69FC1A4D5C09) }); + else if constexpr (VecSize == 4) + return sycl::vec({ UINT64_C(1), UINT64_C(0x113769B23C5FD), + UINT64_C(0x65C69FC1A4D5C09), UINT64_C(0x1CE44D68E81E1E5) }); + else if constexpr (VecSize == 8) + return sycl::vec({ UINT64_C(1), UINT64_C(0x113769B23C5FD), + UINT64_C(0x65C69FC1A4D5C09), UINT64_C(0x1CE44D68E81E1E5), + UINT64_C(0x2F861CA52807851), UINT64_C(0x1CCDF2FE3A03D0D), + UINT64_C(0x707AB5B7C1E56D9), UINT64_C(0x6139AE457BD175) }); + else + return sycl::vec( + { UINT64_C(1), UINT64_C(0x113769B23C5FD), UINT64_C(0x65C69FC1A4D5C09), + UINT64_C(0x1CE44D68E81E1E5), UINT64_C(0x2F861CA52807851), UINT64_C(0x1CCDF2FE3A03D0D), + UINT64_C(0x707AB5B7C1E56D9), UINT64_C(0x6139AE457BD175), UINT64_C(0x171CF606D8C09A1), + UINT64_C(0x3764DC8D2D1691D), UINT64_C(0x50A1576CCF32A9), UINT64_C(0x499F3083ADC1E05), + UINT64_C(0x7A30C00B05283F1), UINT64_C(0x4FE299EB607DA2D), UINT64_C(0x51CCFD803CE3F79), + UINT64_C(0x58145D06A37D795) }); +} + +// hipSYCL (AdaptiveCpp) doesn't support constexpr sycl::vec constructor +// that's why in case of hipSYCL backend sycl::vec is created as a local variable +#ifndef __HIPSYCL__ +template +struct mcg59_vector_a { + static constexpr sycl::vec vector_a = + select_vector_a_mcg59(); // powers of a +}; +#endif + +struct mcg59_param { + static constexpr uint64_t a = 0x113769B23C5FD; // 13^13 + static constexpr uint64_t m_64 = 0x7FFFFFFFFFFFFFF; // 2^59 - 1 + static constexpr float m_fl = 576460752303423488.0f; // 2^59 +}; + +template +struct engine_state> { + std::uint64_t s; +}; + +namespace mcg59_impl { + +template +static inline T custom_mod(T x) { + return (x & mcg59_param::m_64); +} + +static inline std::uint64_t power(std::uint64_t a, std::uint64_t n) { + // initialize result by 1 for recurrency + std::uint64_t result = 1; + if (n == 0) { + // return (a^0)%m = 1 + return 1; + } + do { + // For each odd n + if (n & 1) { + result = custom_mod(result * a); + } + // n := n/2 + n >>= 1; + a = custom_mod(a * a); + } while (n); + + return result; +} + +template +static inline void skip_ahead(engine_state>& state, + std::uint64_t num_to_skip) { + std::uint64_t loc_A = power(mcg59_param::a, num_to_skip); + state.s = custom_mod(loc_A * state.s); +} + +template +static inline void init(engine_state>& state, + std::uint64_t seed, std::uint64_t offset) { + state.s = seed & mcg59_param::m_64; + if (state.s == 0) + state.s = 1; + + skip_ahead(state, offset); +} + +template +static inline sycl::vec generate( + engine_state>& state) { + sycl::vec res(state.s); +#ifndef __HIPSYCL__ + res = custom_mod(mcg59_vector_a::vector_a * res); +#else + // a workaround for hipSYCL (AdaptiveCpp) + res = custom_mod(select_vector_a_mcg59() * res); +#endif + state.s = custom_mod(mcg59_param::a * res[VecSize - 1]); + return res; +} + +template +static inline std::uint64_t generate_single( + engine_state>& state) { + std::uint64_t x = state.s; + state.s = custom_mod(mcg59_param::a * x); + return x; +} + +} // namespace mcg59_impl + +template +class engine_base> { +protected: + engine_base(std::uint64_t seed, std::uint64_t offset = 0) { + mcg59_impl::init(this->state_, seed, offset); + } + + template + auto generate(RealType a, RealType b) -> + typename std::conditional>::type { + sycl::vec res; + + RealType c = (b - a) / static_cast(mcg59_param::m_fl); + sycl::vec res_uint = mcg59_impl::generate(this->state_); + + res = res_uint.template convert() * c + a; + + return res; + } + + auto generate() -> typename std::conditional>::type { + return mcg59_impl::generate(this->state_); + } + + auto generate_bits() -> typename std::conditional>::type { + return mcg59_impl::generate(this->state_); + } + + template + auto generate_uniform_bits() -> + typename std::conditional>::type { + if constexpr (std::is_same::value) { + auto uni_res = mcg59_impl::generate(this->state_); + + if constexpr (VecSize == 1) { + return static_cast(uni_res[0] >> 27); + } + else { + sycl::vec vec_out; + + for (std::int32_t i = 0; i < VecSize; i++) { + vec_out[i] = static_cast(uni_res[i] >> 27); + } + + return vec_out; + } + } + else { + auto uni_res1 = mcg59_impl::generate(this->state_); + auto uni_res2 = mcg59_impl::generate(this->state_); + + if constexpr (VecSize == 1) { + uni_res1 >>= UIntType(27); + uni_res2 >>= UIntType(27); + + return (uni_res2 << UIntType(32)) + uni_res1; + } + else { + sycl::vec vec_out; + + for (int i = 0; i < VecSize; i++) { + uni_res1[i] >>= 27; + uni_res2[i] >>= 27; + } + + if constexpr (VecSize != 3) { + for (int i = 0; i < VecSize / 2; i++) { + vec_out[i] = (uni_res1[2 * i + 1] << 32) + uni_res1[2 * i]; + vec_out[i + VecSize / 2] = (uni_res2[2 * i + 1] << 32) + uni_res2[2 * i]; + } + } + else { + vec_out[0] = (uni_res1[1] << 32) + uni_res1[0]; + vec_out[1] = (uni_res2[0] << 32) + uni_res1[2]; + vec_out[2] = (uni_res2[2] << 32) + uni_res2[1]; + } + + return vec_out; + } + } + } + + template + RealType generate_single(RealType a, RealType b) { + RealType res; + std::uint64_t res_uint; + + RealType c = (b - a) / static_cast(mcg59_param::m_fl); + + res_uint = mcg59_impl::generate_single(this->state_); + res = static_cast(res_uint) * c + a; + + return res; + } + + auto generate_single() { + return mcg59_impl::generate_single(this->state_); + } + + template + auto generate_single_uniform_bits() { + if constexpr (std::is_same::value) { + auto uni_res = mcg59_impl::generate_single(this->state_) >> 27; + + return static_cast(uni_res); + } + else { + auto uni_res1 = mcg59_impl::generate_single(this->state_); + auto uni_res2 = mcg59_impl::generate_single(this->state_); + + uni_res1 >>= 27; + uni_res2 >>= 27; + + return (uni_res2 << 32) + uni_res1; + } + } + + void skip_ahead(std::uint64_t num_to_skip) { + detail::mcg59_impl::skip_ahead(this->state_, num_to_skip); + } + + engine_state> state_; +}; + +} // namespace detail +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_MCG59_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/mrg32k3a_impl.hpp b/include/oneapi/mkl/rng/device/detail/mrg32k3a_impl.hpp new file mode 100644 index 000000000..596e625ad --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/mrg32k3a_impl.hpp @@ -0,0 +1,384 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// References: +// [1] Bradley, Thomas & du Toit, Jacques & Giles, Mike & Tong, Robert & Woodhams, Paul. +// (2011). Parallelisation Techniques for Random Number Generators. +// GPU Computing Gems Emerald Edition. 10.1016/B978-0-12-384988-5.00016-4 + +#ifndef _MKL_RNG_DEVICE_MRG32K3A_IMPL_HPP_ +#define _MKL_RNG_DEVICE_MRG32K3A_IMPL_HPP_ + +#include "oneapi/mkl/rng/device/detail/mrg32k3a_skip_ahead_matrix.hpp" + +namespace oneapi::mkl::rng::device { + +template +class mrg32k3a; + +namespace detail { + +template +struct engine_state> { + std::uint32_t s[6]; +}; + +namespace mrg32k3a_impl { + +struct mrg32k3a_params { + static constexpr std::uint32_t m1 = 4294967087; + static constexpr std::uint32_t m2 = 4294944443; + static constexpr std::uint32_t a12 = 1403580; + static constexpr std::uint32_t a13 = 4294156359; + static constexpr std::uint32_t a21 = 527612; + static constexpr std::uint32_t a23 = 4293573854; + static constexpr std::uint32_t a13n = 810728; + static constexpr std::uint32_t a23n = 1370589; +}; + +template +struct two_pow_32_minus_m {}; + +template <> +struct two_pow_32_minus_m { + static constexpr std::int64_t val = 209; +}; + +template <> +struct two_pow_32_minus_m { + static constexpr std::int64_t val = 22853; +}; + +template +static inline void bit_shift_and_mask(T& in) { + T mask; + if constexpr (std::is_same_v) { + mask = 0x00000000ffffffffu; + } + else { + mask = 0x00000000ffffffff; + } + in = ((in >> 32) * two_pow_32_minus_m::val + (in & mask)); +} + +template +static inline void matr3x3_vec_mul_mod(std::uint32_t a[3][3], std::uint32_t x[3], + std::uint32_t y[3]) { + std::uint64_t temp[3] = { 0ull, 0ull, 0ull }; + for (int i = 0; i < 3; ++i) { + for (int k = 0; k < 3; ++k) { + std::uint64_t tmp = + static_cast(a[i][k]) * static_cast(x[k]); + bit_shift_and_mask(tmp); + bit_shift_and_mask(tmp); + if (tmp >= M) { + tmp -= M; + } + temp[i] += tmp; + } + bit_shift_and_mask(temp[i]); + if (temp[i] >= M) { + temp[i] -= M; + } + } + + for (int k = 0; k < 3; k++) { + y[k] = static_cast(temp[k]); + } + + return; +} + +template +static inline void matr3x3_mul_mod(std::uint32_t B[3][3], + const std::uint32_t _skip_ahead_matrix[3][3]) { + std::uint64_t temp[3][3] = { { 0ull, 0ull, 0ull }, { 0ull, 0ull, 0ull }, { 0ull, 0ull, 0ull } }; + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + for (int k = 0; k < 3; ++k) { + std::uint64_t tmp = static_cast(B[i][k]) * + static_cast(_skip_ahead_matrix[k][j]); + bit_shift_and_mask(tmp); + if constexpr (mrg32k3a_params::m2 == M) { + bit_shift_and_mask(tmp); + } + if (tmp >= M) { + tmp -= M; + } + temp[i][j] += tmp; + } + bit_shift_and_mask(temp[i][j]); + if (temp[i][j] >= M) { + temp[i][j] -= M; + } + } + } + + for (int i = 0; i < 3; ++i) { + for (int j = 0; j < 3; ++j) { + B[i][j] = static_cast(temp[i][j]); + } + } +} + +template +static inline void vec3_pow_mod( + std::uint32_t x[3], std::uint64_t n, const std::uint64_t* skip_params, + const std::uint32_t _skip_ahead_matrix[quantity_of_3x3_matrices][3][3]) { + std::uint32_t B[3][3] = { { 1u, 0u, 0u }, { 0u, 1u, 0u }, { 0u, 0u, 1u } }; + + std::uint32_t off; + std::uint32_t mod; + std::uint64_t skip_param; + std::uint32_t bit_count = 0; // can be 0, 1, 2 + std::uint32_t bit_count_tmp; + + for (std::uint32_t j = 0; j < n; j++) { + skip_param = skip_params[j]; + off = 0; + bit_count_tmp = bit_count; + while (skip_param) { + // we have to multiply skip_param[1] by 2 and skip_params[2] by 4 only for the 1st iteration + // of the loop to get the required power of a power-of-eight matrice from a power of two + mod = (skip_param << static_cast(bit_count_tmp)) & + 7ull; // == (skip_param * _mult) % 8, _mult={1,2,4} + if (mod) { + // 7 - number of 3x3 matrices of some power of 8: 1*8^x, 2*8^x, ..., 7*8^x + // 7 * 21 - number of 3x3 matrices for each skip parameter + matr3x3_mul_mod(B, _skip_ahead_matrix[7 * 21 * j + off * 7 + (mod - 1)]); + } + skip_param = + skip_param / + (8ull >> static_cast(bit_count_tmp)); // == skip_param / (8 / _mult) + ++off; + bit_count_tmp = 0; + } + ++bit_count; + } + matr3x3_vec_mul_mod(B, x, x); +} + +template +static inline void skip_ahead(engine_state>& state, + std::uint64_t n, const std::uint64_t* num_to_skip_ptr) { + if (n > 3) { + n = 3; +#ifndef __SYCL_DEVICE_ONLY__ + throw oneapi::mkl::invalid_argument("rng", "mrg32k3a", + "period is 2 ^ 191, skip on more than 2^192"); +#endif + } + vec3_pow_mod(state.s, n, num_to_skip_ptr, skip_ahead_matrix[0]); + vec3_pow_mod(state.s + 3, n, num_to_skip_ptr, skip_ahead_matrix[1]); +} + +template +static inline void validate_seed(engine_state>& state) { + int i; + for (i = 0; i < 3; i++) { + if (state.s[i] >= mrg32k3a_params::m1) { + state.s[i] -= mrg32k3a_params::m1; + } + } + for (; i < 6; i++) { + if (state.s[i] >= mrg32k3a_params::m2) { + state.s[i] -= mrg32k3a_params::m2; + } + } + + if ((state.s[0]) == 0 && (state.s[1]) == 0 && (state.s[2]) == 0) { + state.s[0] = 1; + } + if ((state.s[3]) == 0 && (state.s[4]) == 0 && (state.s[5]) == 0) { + state.s[3] = 1; + } +} + +template +static inline void init(engine_state>& state, + std::uint64_t n, const std::uint32_t* seed_ptr, std::uint64_t n_offset, + const std::uint64_t* offset_ptr) { + std::uint64_t i; + if (n > 6) { + n = 6; + } + for (i = 0; i < n; i++) { + state.s[i] = seed_ptr[i]; + } + for (; i < 6; i++) { + state.s[i] = 1; + } + validate_seed(state); + mrg32k3a_impl::skip_ahead(state, n_offset, offset_ptr); +} + +template +static inline sycl::vec generate( + engine_state>& state) { + const std::int32_t num_elements = VecSize; + sycl::vec res; + std::int64_t x, y; + std::int32_t i = 0; + for (i = 0; i < num_elements; i++) { + x = mrg32k3a_params::a12 * static_cast(state.s[1]) - + mrg32k3a_params::a13n * static_cast(state.s[0]); + // perform modulus + bit_shift_and_mask(x); + if (x >= mrg32k3a_params::m1) + x -= mrg32k3a_params::m1; + x += ((x & 0x8000000000000000) >> 63) * mrg32k3a_params::m1; + y = mrg32k3a_params::a21 * static_cast(state.s[5]) - + mrg32k3a_params::a23n * static_cast(state.s[3]); + // perform modulus + bit_shift_and_mask(y); + bit_shift_and_mask(y); + if (y >= mrg32k3a_params::m2) + y -= mrg32k3a_params::m2; + y += ((y & 0x8000000000000000) >> 63) * mrg32k3a_params::m2; + state.s[0] = state.s[1]; + state.s[1] = state.s[2]; + state.s[2] = x; + state.s[3] = state.s[4]; + state.s[4] = state.s[5]; + state.s[5] = y; + if (x <= y) { + res[i] = x + (mrg32k3a_params::m1 - y); + } + else { + res[i] = x - y; + } + } + return res; +} + +template +static inline std::uint32_t generate_single( + engine_state>& state) { + std::uint32_t res; + std::int64_t x, y; + x = mrg32k3a_params::a12 * static_cast(state.s[1]) - + mrg32k3a_params::a13n * static_cast(state.s[0]); + // perform modulus + bit_shift_and_mask(x); + if (x >= mrg32k3a_params::m1) + x -= mrg32k3a_params::m1; + x += ((x & 0x8000000000000000) >> 63) * mrg32k3a_params::m1; + y = mrg32k3a_params::a21 * static_cast(state.s[5]) - + mrg32k3a_params::a23n * static_cast(state.s[3]); + // perform modulus + bit_shift_and_mask(y); + bit_shift_and_mask(y); + if (y >= mrg32k3a_params::m2) + y -= mrg32k3a_params::m2; + y += ((y & 0x8000000000000000) >> 63) * mrg32k3a_params::m2; + state.s[0] = state.s[1]; + state.s[1] = state.s[2]; + state.s[2] = x; + state.s[3] = state.s[4]; + state.s[4] = state.s[5]; + state.s[5] = y; + if (x <= y) { + res = x + (mrg32k3a_params::m1 - y); + } + else { + res = x - y; + } + + return res; +} + +} // namespace mrg32k3a_impl + +template +class engine_base> { +protected: + engine_base(std::uint32_t seed, std::uint64_t offset = 0) { + mrg32k3a_impl::init(this->state_, 1, &seed, 1, &offset); + } + + engine_base(std::uint64_t n, const std::uint32_t* seed, std::uint64_t offset = 0) { + mrg32k3a_impl::init(this->state_, n, seed, 1, &offset); + } + + engine_base(std::uint32_t seed, std::uint64_t n_offset, const std::uint64_t* offset_ptr) { + mrg32k3a_impl::init(this->state_, 1, &seed, n_offset, offset_ptr); + } + + engine_base(std::uint64_t n, const std::uint32_t* seed, std::uint64_t n_offset, + const std::uint64_t* offset_ptr) { + mrg32k3a_impl::init(this->state_, n, seed, n_offset, offset_ptr); + } + + template + auto generate(RealType a, RealType b) -> + typename std::conditional>::type { + sycl::vec res; + sycl::vec res_uint; + RealType c; + + c = (b - a) / (static_cast(mrg32k3a_impl::mrg32k3a_params::m1)); + + res_uint = mrg32k3a_impl::generate(this->state_); + + for (int i = 0; i < VecSize; i++) { + res[i] = (RealType)(res_uint[i]) * c + a; + } + return res; + } + + auto generate() -> typename std::conditional>::type { + return mrg32k3a_impl::generate(this->state_); + } + + template + RealType generate_single(RealType a, RealType b) { + RealType res; + std::uint32_t res_uint; + RealType c; + + c = (b - a) / (static_cast(mrg32k3a_impl::mrg32k3a_params::m1)); + + res_uint = mrg32k3a_impl::generate_single(this->state_); + + res = (RealType)(res_uint)*c + a; + + return res; + } + + std::uint32_t generate_single() { + return mrg32k3a_impl::generate_single(this->state_); + } + + void skip_ahead(std::uint64_t num_to_skip) { + detail::mrg32k3a_impl::skip_ahead(this->state_, 1, &num_to_skip); + } + + void skip_ahead(std::initializer_list num_to_skip) { + detail::mrg32k3a_impl::skip_ahead(this->state_, num_to_skip.size(), num_to_skip.begin()); + } + + engine_state> state_; +}; + +} // namespace detail +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_MRG32K3A_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/mrg32k3a_skip_ahead_matrix.hpp b/include/oneapi/mkl/rng/device/detail/mrg32k3a_skip_ahead_matrix.hpp new file mode 100644 index 000000000..d1ea8c263 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/mrg32k3a_skip_ahead_matrix.hpp @@ -0,0 +1,3668 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_MRG32K3A_SKIP_AHEAD_MATRIX_HPP_ +#define _MKL_RNG_DEVICE_MRG32K3A_SKIP_AHEAD_MATRIX_HPP_ + +namespace oneapi::mkl::rng::device::detail { +namespace mrg32k3a_impl { + +constexpr std::size_t quantity_of_3x3_matrices = 455; // number of 3x3 matrices for skipping + +// There are 2 subsequences of numbers, each containing 455 3x3 matrices +static const std::uint32_t skip_ahead_matrix[2][quantity_of_3x3_matrices][3][3] = { + // Matrices for the first part of SkipAhead procedure + // Matrix for nskip = 1 * 8 ^ 0: + { { { 0, 1, 0 }, { 0, 0, 1 }, { 4294156359, 1403580, 0 } }, + // Matrix for nskip = 2 * 8 ^ 0: + { { 0, 0, 1 }, { 4294156359, 1403580, 0 }, { 0, 4294156359, 1403580 } }, + // Matrix for nskip = 3 * 8 ^ 0: + { { 4294156359, 1403580, 0 }, + { 0, 4294156359, 1403580 }, + { 244671815, 2941890554, 4294156359 } }, + // Matrix for nskip = 4 * 8 ^ 0: + { { 0, 4294156359, 1403580 }, + { 244671815, 2941890554, 4294156359 }, + { 149925673, 489343630, 2941890554 } }, + // Matrix for nskip = 5 * 8 ^ 0: + { { 244671815, 2941890554, 4294156359 }, + { 149925673, 489343630, 2941890554 }, + { 3782722441, 1831234280, 489343630 } }, + // Matrix for nskip = 6 * 8 ^ 0: + { { 149925673, 489343630, 2941890554 }, + { 3782722441, 1831234280, 489343630 }, + { 1527363550, 2758233149, 1831234280 } }, + // Matrix for nskip = 7 * 8 ^ 0: + { { 3782722441, 1831234280, 489343630 }, + { 1527363550, 2758233149, 1831234280 }, + { 4072640363, 939574583, 2758233149 } }, + // Matrix for nskip = 1 * 8 ^ 1: + { { 1527363550, 2758233149, 1831234280 }, + { 4072640363, 939574583, 2758233149 }, + { 2064391165, 3228066636, 939574583 } }, + // Matrix for nskip = 2 * 8 ^ 1: + { { 736416029, 2961816100, 342112271 }, + { 387300998, 1062452522, 2961816100 }, + { 2955879160, 340793741, 1062452522 } }, + // Matrix for nskip = 3 * 8 ^ 1: + { { 3830731060, 3351104823, 355092062 }, + { 4271633387, 3081436279, 3351104823 }, + { 2754512837, 673113417, 3081436279 } }, + // Matrix for nskip = 4 * 8 ^ 1: + { { 1243502014, 2218748291, 1709215645 }, + { 2019641772, 3847560959, 2218748291 }, + { 3866010231, 2305448679, 3847560959 } }, + // Matrix for nskip = 5 * 8 ^ 1: + { { 753665800, 3956261650, 1880714717 }, + { 3889504807, 299844503, 3956261650 }, + { 3555787878, 734199116, 299844503 } }, + // Matrix for nskip = 6 * 8 ^ 1: + { { 1402917279, 671479916, 279477115 }, + { 1066184965, 1957999095, 671479916 }, + { 3803905489, 2154014226, 1957999095 } }, + // Matrix for nskip = 7 * 8 ^ 1: + { { 1519817277, 3513041072, 37163717 }, + { 3823126416, 1394452522, 3513041072 }, + { 762181894, 1046733826, 1394452522 } }, + // Matrix for nskip = 1 * 8 ^ 2: + { { 3241775219, 3453352062, 3721871040 }, + { 4062454730, 3015754, 3453352062 }, + { 919711945, 613405362, 3015754 } }, + // Matrix for nskip = 2 * 8 ^ 2: + { { 1955221006, 1414472808, 1746037714 }, + { 3653507277, 1644962013, 1414472808 }, + { 3501544776, 2336229602, 1644962013 } }, + // Matrix for nskip = 3 * 8 ^ 2: + { { 2883496440, 2415235089, 3754924652 }, + { 2873360987, 3093961248, 2415235089 }, + { 2551531030, 3967481377, 3093961248 } }, + // Matrix for nskip = 4 * 8 ^ 2: + { { 1170096663, 49135452, 3441537107 }, + { 1857945175, 1649398389, 49135452 }, + { 333002869, 3109147376, 1649398389 } }, + // Matrix for nskip = 5 * 8 ^ 2: + { { 3782304170, 536558728, 1207462427 }, + { 2479820532, 1357898065, 536558728 }, + { 3967038637, 280429670, 1357898065 } }, + // Matrix for nskip = 6 * 8 ^ 2: + { { 1850220783, 2237648487, 4288110946 }, + { 778070070, 3729077970, 2237648487 }, + { 1095506872, 3284249345, 3729077970 } }, + // Matrix for nskip = 7 * 8 ^ 2: + { { 3963964167, 1824244353, 1280698295 }, + { 1736039316, 2491872331, 1824244353 }, + { 1645622379, 4226305484, 2491872331 } }, + // Matrix for nskip = 1 * 8 ^ 3: + { { 2299034194, 2297111910, 862649200 }, + { 1399961132, 996706937, 2297111910 }, + { 3439056503, 1481993076, 996706937 } }, + // Matrix for nskip = 2 * 8 ^ 3: + { { 4146310528, 458782589, 1007330283 }, + { 4241015765, 3979619964, 458782589 }, + { 553886495, 2186897562, 3979619964 } }, + // Matrix for nskip = 3 * 8 ^ 3: + { { 1146235803, 3119708691, 3977084597 }, + { 1030264372, 1706820424, 3119708691 }, + { 2210423860, 4154877869, 1706820424 } }, + // Matrix for nskip = 4 * 8 ^ 3: + { { 3630027893, 2130448350, 292773857 }, + { 1392525159, 1299285967, 2130448350 }, + { 2589171163, 1217405758, 1299285967 } }, + // Matrix for nskip = 5 * 8 ^ 3: + { { 3841954865, 948545149, 4067146304 }, + { 4218117763, 3741945962, 948545149 }, + { 1745368878, 730788749, 3741945962 } }, + // Matrix for nskip = 6 * 8 ^ 3: + { { 2341737887, 1393299668, 3386176735 }, + { 1655556841, 359678770, 1393299668 }, + { 2175543957, 3314680006, 359678770 } }, + // Matrix for nskip = 7 * 8 ^ 3: + { { 3121396438, 3210334684, 1062918236 }, + { 325732785, 2721675172, 3210334684 }, + { 3182328265, 241385543, 2721675172 } }, + // Matrix for nskip = 1 * 8 ^ 4: + { { 892409263, 1999175811, 2979225418 }, + { 1996163538, 2148702503, 1999175811 }, + { 3922720782, 103819730, 2148702503 } }, + // Matrix for nskip = 2 * 8 ^ 4: + { { 1586003016, 2114210471, 3240775579 }, + { 2777288607, 1400478398, 2114210471 }, + { 3018215420, 535326008, 1400478398 } }, + // Matrix for nskip = 3 * 8 ^ 4: + { { 377225862, 1098715579, 1378248654 }, + { 2452527982, 3677219860, 1098715579 }, + { 3805011027, 3962510930, 3677219860 } }, + // Matrix for nskip = 4 * 8 ^ 4: + { { 2188531273, 1783231160, 3576659343 }, + { 1908318389, 379210133, 1783231160 }, + { 554369329, 250053591, 379210133 } }, + // Matrix for nskip = 5 * 8 ^ 4: + { { 2249717607, 2266741858, 2040546316 }, + { 3093925525, 3510732546, 2266741858 }, + { 2244264588, 3926709784, 3510732546 } }, + // Matrix for nskip = 6 * 8 ^ 4: + { { 2349663769, 2339070143, 3651849809 }, + { 1360064932, 443349145, 2339070143 }, + { 2864061919, 590074072, 443349145 } }, + // Matrix for nskip = 7 * 8 ^ 4: + { { 299115015, 4017647307, 737449908 }, + { 1014398637, 352887003, 4017647307 }, + { 2268496651, 499779786, 352887003 } }, + // Matrix for nskip = 1 * 8 ^ 5: + { { 4022841636, 3951951872, 2143424240 }, + { 1046219306, 1591992468, 3951951872 }, + { 1510277444, 381333958, 1591992468 } }, + // Matrix for nskip = 2 * 8 ^ 5: + { { 2256493727, 3715182130, 642697923 }, + { 3615342722, 3975008370, 3715182130 }, + { 2405650329, 754337639, 3975008370 } }, + // Matrix for nskip = 3 * 8 ^ 5: + { { 3246129870, 3068844475, 3738266208 }, + { 668859604, 3798586786, 3068844475 }, + { 3275530821, 2740099935, 3798586786 } }, + // Matrix for nskip = 4 * 8 ^ 5: + { { 1286664224, 627406673, 963516608 }, + { 1541344588, 460768826, 627406673 }, + { 1089892553, 2717717970, 460768826 } }, + // Matrix for nskip = 5 * 8 ^ 5: + { { 2092934033, 2692683366, 2826944083 }, + { 1909409603, 3350132528, 2692683366 }, + { 3481095738, 3485350450, 3350132528 } }, + // Matrix for nskip = 6 * 8 ^ 5: + { { 1918719231, 2970279915, 803149880 }, + { 2389311995, 4195833089, 2970279915 }, + { 166509779, 2105299796, 4195833089 } }, + // Matrix for nskip = 7 * 8 ^ 5: + { { 3252663202, 2481165293, 694007918 }, + { 1921953957, 350878042, 2481165293 }, + { 1954500233, 1970948165, 350878042 } }, + // Matrix for nskip = 1 * 8 ^ 6: + { { 2956342842, 3471097641, 2353092905 }, + { 2996150472, 420480221, 3471097641 }, + { 2221681883, 372736411, 420480221 } }, + // Matrix for nskip = 2 * 8 ^ 6: + { { 420492906, 153526651, 3499730988 }, + { 2662640502, 3278195133, 153526651 }, + { 4086436419, 2510762118, 3278195133 } }, + // Matrix for nskip = 3 * 8 ^ 6: + { { 600928360, 715341436, 3127996992 }, + { 4276221887, 1953220754, 715341436 }, + { 2074032202, 163100603, 1953220754 } }, + // Matrix for nskip = 4 * 8 ^ 6: + { { 3310184147, 2228376089, 823220763 }, + { 3992771814, 1693168425, 2228376089 }, + { 2295790366, 1401872772, 1693168425 } }, + // Matrix for nskip = 5 * 8 ^ 6: + { { 1282168185, 2751813658, 602760489 }, + { 2254465781, 1232521545, 2751813658 }, + { 1025381169, 1981662800, 1232521545 } }, + // Matrix for nskip = 6 * 8 ^ 6: + { { 460755919, 4283511820, 3208183750 }, + { 3248110895, 730327118, 4283511820 }, + { 1386862282, 926261676, 730327118 } }, + // Matrix for nskip = 7 * 8 ^ 6: + { { 2392208153, 3129124418, 684400653 }, + { 4025364146, 1122067473, 3129124418 }, + { 773418203, 2967386517, 1122067473 } }, + // Matrix for nskip = 1 * 8 ^ 7: + { { 2529428830, 1497104068, 4253248635 }, + { 3746310018, 630867741, 1497104068 }, + { 627043435, 721725795, 630867741 } }, + // Matrix for nskip = 2 * 8 ^ 7: + { { 2571072593, 3039669025, 1591031831 }, + { 526054481, 661344445, 3039669025 }, + { 4246010312, 735391270, 661344445 } }, + // Matrix for nskip = 3 * 8 ^ 7: + { { 3781620139, 2917363935, 2936154555 }, + { 2668364492, 3297773364, 2917363935 }, + { 2501878263, 3438979384, 3297773364 } }, + // Matrix for nskip = 4 * 8 ^ 7: + { { 1847312821, 4042890210, 4241772463 }, + { 606605705, 2644799309, 4042890210 }, + { 2658402822, 1342278931, 2644799309 } }, + // Matrix for nskip = 5 * 8 ^ 7: + { { 3502592220, 3704088248, 4011400538 }, + { 2932838910, 1175764916, 3704088248 }, + { 2865336247, 2471593729, 1175764916 } }, + // Matrix for nskip = 6 * 8 ^ 7: + { { 3250474907, 3775615386, 3733878711 }, + { 1502779384, 287728234, 3775615386 }, + { 162441370, 246229618, 287728234 } }, + // Matrix for nskip = 7 * 8 ^ 7: + { { 749636765, 3227070913, 3120894575 }, + { 2853687796, 1910085226, 3227070913 }, + { 2453891386, 4230641571, 1910085226 } }, + // Matrix for nskip = 1 * 8 ^ 8: + { { 2409846784, 1096138313, 1416249993 }, + { 1501878241, 138013862, 1096138313 }, + { 1617749306, 1975136163, 138013862 } }, + // Matrix for nskip = 2 * 8 ^ 8: + { { 599453422, 73950522, 2965395603 }, + { 55354701, 3855242202, 73950522 }, + { 3981734504, 3354399019, 3855242202 } }, + // Matrix for nskip = 3 * 8 ^ 8: + { { 3515748818, 1941532786, 3590950415 }, + { 3557298699, 2872969148, 1941532786 }, + { 3200219335, 3657910297, 2872969148 } }, + // Matrix for nskip = 4 * 8 ^ 8: + { { 4271076381, 813410089, 3461955319 }, + { 1044920137, 3029005516, 813410089 }, + { 3501837362, 3321539504, 3029005516 } }, + // Matrix for nskip = 5 * 8 ^ 8: + { { 1749168476, 312277958, 960113158 }, + { 3444686334, 4207289909, 312277958 }, + { 2940543965, 559813450, 4207289909 } }, + // Matrix for nskip = 6 * 8 ^ 8: + { { 316005085, 3130396563, 3837877063 }, + { 1625744025, 2903706877, 3130396563 }, + { 201947523, 3713704391, 2903706877 } }, + // Matrix for nskip = 7 * 8 ^ 8: + { { 2725645318, 3806079268, 2159958180 }, + { 1110389513, 1295130289, 3806079268 }, + { 2596032611, 1951986222, 1295130289 } }, + // Matrix for nskip = 1 * 8 ^ 9: + { { 3058183515, 941408572, 1783998098 }, + { 1546486080, 4116985007, 941408572 }, + { 2247500745, 1460625377, 4116985007 } }, + // Matrix for nskip = 2 * 8 ^ 9: + { { 4216782514, 3352801941, 2315095646 }, + { 639029973, 94451952, 3352801941 }, + { 1242898773, 3964593332, 94451952 } }, + // Matrix for nskip = 3 * 8 ^ 9: + { { 3704530610, 1763750345, 4252200234 }, + { 3310872720, 3465004782, 1763750345 }, + { 1602573750, 530766064, 3465004782 } }, + // Matrix for nskip = 4 * 8 ^ 9: + { { 2264905138, 1926285644, 1108147171 }, + { 2390706911, 385258225, 1926285644 }, + { 3569882325, 3728744670, 385258225 } }, + // Matrix for nskip = 5 * 8 ^ 9: + { { 1104250853, 2649508927, 1011964068 }, + { 1303004323, 2245340871, 2649508927 }, + { 2225918280, 1790484033, 2245340871 } }, + // Matrix for nskip = 6 * 8 ^ 9: + { { 704130800, 2663175885, 3195438389 }, + { 2578332381, 377826974, 2663175885 }, + { 3055477316, 116744102, 377826974 } }, + // Matrix for nskip = 7 * 8 ^ 9: + { { 1534677729, 1538922981, 1955454860 }, + { 3358514099, 279668397, 1538922981 }, + { 1333529549, 1503627474, 279668397 } }, + // Matrix for nskip = 1 * 8 ^ 10: + { { 270679073, 1065683096, 2992662885 }, + { 4196917281, 2886425156, 1065683096 }, + { 749134119, 1849148167, 2886425156 } }, + // Matrix for nskip = 2 * 8 ^ 10: + { { 35689930, 1378151623, 951629713 }, + { 673810920, 948843427, 1378151623 }, + { 3808868984, 927013635, 948843427 } }, + // Matrix for nskip = 3 * 8 ^ 10: + { { 1708907294, 3971013929, 120796985 }, + { 341462694, 1820387182, 3971013929 }, + { 658508974, 1448556483, 1820387182 } }, + // Matrix for nskip = 4 * 8 ^ 10: + { { 1891490872, 1130489594, 3734864133 }, + { 1457450350, 3362920032, 1130489594 }, + { 638998846, 1401175590, 3362920032 } }, + // Matrix for nskip = 5 * 8 ^ 10: + { { 2493538871, 1119726169, 3415942617 }, + { 3041636598, 2163282065, 1119726169 }, + { 3770868549, 1056545317, 2163282065 } }, + // Matrix for nskip = 6 * 8 ^ 10: + { { 3254893662, 3244521128, 1199630310 }, + { 4235017122, 2943451417, 3244521128 }, + { 2697569444, 4187443436, 2943451417 } }, + // Matrix for nskip = 7 * 8 ^ 10: + { { 4046281084, 3800263816, 3215056790 }, + { 1654449614, 386290994, 3800263816 }, + { 1471940141, 481393463, 386290994 } }, + // Matrix for nskip = 1 * 8 ^ 11: + { { 2254459023, 2384691454, 1730098031 }, + { 2844861718, 1807491073, 2384691454 }, + { 351423668, 1570264155, 1807491073 } }, + // Matrix for nskip = 2 * 8 ^ 11: + { { 3047429268, 4245359555, 2449575498 }, + { 1797081212, 1237196477, 4245359555 }, + { 143400628, 3663731096, 1237196477 } }, + // Matrix for nskip = 3 * 8 ^ 11: + { { 2147359263, 1349445168, 2733446300 }, + { 1305907164, 210670816, 1349445168 }, + { 2509073771, 839244126, 210670816 } }, + // Matrix for nskip = 4 * 8 ^ 11: + { { 3313321106, 4263819658, 1047529624 }, + { 3719941673, 3155049403, 4263819658 }, + { 1981313839, 4281524426, 3155049403 } }, + // Matrix for nskip = 5 * 8 ^ 11: + { { 1429567203, 899246895, 3248764453 }, + { 2783815531, 108747348, 899246895 }, + { 256526168, 1467875854, 108747348 } }, + // Matrix for nskip = 6 * 8 ^ 11: + { { 2740000743, 1423127512, 1283194774 }, + { 700110581, 582760735, 1423127512 }, + { 571933335, 785351190, 582760735 } }, + // Matrix for nskip = 7 * 8 ^ 11: + { { 448747464, 852164586, 412380392 }, + { 497540878, 2374838356, 852164586 }, + { 1830234951, 2052902650, 2374838356 } }, + // Matrix for nskip = 1 * 8 ^ 12: + { { 2005252417, 3263186729, 1535805957 }, + { 2951515865, 1729281525, 3263186729 }, + { 1141249417, 2268963059, 1729281525 } }, + // Matrix for nskip = 2 * 8 ^ 12: + { { 2367065164, 83908466, 4294308508 }, + { 1352516724, 1416676049, 83908466 }, + { 1040867745, 1304732377, 1416676049 } }, + // Matrix for nskip = 3 * 8 ^ 12: + { { 2985917792, 4096493219, 1529477403 }, + { 1201774212, 2070059496, 4096493219 }, + { 1675108536, 3110356679, 2070059496 } }, + // Matrix for nskip = 4 * 8 ^ 12: + { { 3214147257, 1434230503, 2944821434 }, + { 2753040912, 4041536918, 1434230503 }, + { 1317260239, 338830578, 4041536918 } }, + // Matrix for nskip = 5 * 8 ^ 12: + { { 3409339184, 2193226133, 1795377731 }, + { 1348686132, 3710830263, 2193226133 }, + { 2242696089, 3564440066, 3710830263 } }, + // Matrix for nskip = 6 * 8 ^ 12: + { { 3189933295, 1475654090, 2785534584 }, + { 4286962883, 2397146654, 1475654090 }, + { 403072156, 2221537290, 2397146654 } }, + // Matrix for nskip = 7 * 8 ^ 12: + { { 741855424, 1898764790, 1822660758 }, + { 1315270526, 1027835471, 1898764790 }, + { 3142787072, 3867031443, 1027835471 } }, + // Matrix for nskip = 1 * 8 ^ 13: + { { 300628476, 2054743463, 1499597869 }, + { 1762244284, 1422043015, 2054743463 }, + { 3581125669, 1207561803, 1422043015 } }, + // Matrix for nskip = 2 * 8 ^ 13: + { { 4171745404, 4064983592, 1934508265 }, + { 3049723261, 1744636487, 4064983592 }, + { 947753516, 3952135907, 1744636487 } }, + // Matrix for nskip = 3 * 8 ^ 13: + { { 392234088, 1933162500, 3586081024 }, + { 4234172394, 2757237142, 1933162500 }, + { 3177450083, 2703743057, 2757237142 } }, + // Matrix for nskip = 4 * 8 ^ 13: + { { 1625369148, 3577024659, 2778677259 }, + { 1729967818, 1049600974, 3577024659 }, + { 2089137344, 1569794605, 1049600974 } }, + // Matrix for nskip = 5 * 8 ^ 13: + { { 24259337, 1099944220, 56936276 }, + { 2473082148, 2484906695, 1099944220 }, + { 4143714563, 1902230902, 2484906695 } }, + // Matrix for nskip = 6 * 8 ^ 13: + { { 53562000, 2164320300, 319591773 }, + { 480516705, 2016775973, 2164320300 }, + { 3670445841, 1306292301, 2016775973 } }, + // Matrix for nskip = 7 * 8 ^ 13: + { { 1588148001, 2552094779, 2777917575 }, + { 3446764329, 4181915770, 2552094779 }, + { 2748502268, 1366641757, 4181915770 } }, + // Matrix for nskip = 1 * 8 ^ 14: + { { 1373068765, 3958611830, 569117280 }, + { 410042396, 3551255470, 3958611830 }, + { 869476379, 1680625376, 3551255470 } }, + // Matrix for nskip = 2 * 8 ^ 14: + { { 2108618602, 2543645250, 913717833 }, + { 2111984988, 1012482542, 2543645250 }, + { 2545745615, 3141042890, 1012482542 } }, + // Matrix for nskip = 3 * 8 ^ 14: + { { 1200101967, 3500039413, 1380082835 }, + { 1489246316, 1939611745, 3500039413 }, + { 1721948148, 3454434256, 1939611745 } }, + // Matrix for nskip = 4 * 8 ^ 14: + { { 1157293598, 584852249, 2272893205 }, + { 1631801979, 3013855247, 584852249 }, + { 3977310441, 82049263, 3013855247 } }, + // Matrix for nskip = 5 * 8 ^ 14: + { { 3527704969, 2070084361, 2336461093 }, + { 675176428, 59273233, 2070084361 }, + { 215288790, 1628101656, 59273233 } }, + // Matrix for nskip = 6 * 8 ^ 14: + { { 3037143591, 2883460010, 26163475 }, + { 1380682893, 3598790241, 2883460010 }, + { 1573828863, 3515570245, 3598790241 } }, + // Matrix for nskip = 7 * 8 ^ 14: + { { 2503812675, 2054481550, 2095990336 }, + { 4200011507, 3373769861, 2054481550 }, + { 1172973983, 1101682881, 3373769861 } }, + // Matrix for nskip = 1 * 8 ^ 15: + { { 3580234334, 3137526662, 2403875621 }, + { 3580869206, 3670086228, 3137526662 }, + { 656744553, 1764904195, 3670086228 } }, + // Matrix for nskip = 2 * 8 ^ 15: + { { 2792496861, 3634185196, 3887031679 }, + { 3601823850, 3464838365, 3634185196 }, + { 3136165138, 2842987937, 3464838365 } }, + // Matrix for nskip = 3 * 8 ^ 15: + { { 860869470, 981305692, 955067142 }, + { 1287512071, 3232580086, 981305692 }, + { 1932329582, 2220460662, 3232580086 } }, + // Matrix for nskip = 4 * 8 ^ 15: + { { 1362557480, 3230022138, 4278720212 }, + { 3427386258, 3848976950, 3230022138 }, + { 2109817045, 2441486578, 3848976950 } }, + // Matrix for nskip = 5 * 8 ^ 15: + { { 2708545360, 267497185, 2662390285 }, + { 13298153, 1401050440, 267497185 }, + { 2610290298, 574376174, 1401050440 } }, + // Matrix for nskip = 6 * 8 ^ 15: + { { 4064509494, 1054794505, 2873059524 }, + { 2518650890, 2583418592, 1054794505 }, + { 2277374582, 2950188629, 2583418592 } }, + // Matrix for nskip = 7 * 8 ^ 15: + { { 43539574, 3585947086, 1551803386 }, + { 4188500293, 3646000753, 3585947086 }, + { 1152314996, 3244390048, 3646000753 } }, + // Matrix for nskip = 1 * 8 ^ 16: + { { 1198519135, 2007945401, 3868481 }, + { 3335076429, 2082683147, 2007945401 }, + { 2341088247, 888193479, 2082683147 } }, + // Matrix for nskip = 2 * 8 ^ 16: + { { 3473925387, 3193380570, 565138859 }, + { 307060547, 782210925, 3193380570 }, + { 167617770, 2180014252, 782210925 } }, + // Matrix for nskip = 3 * 8 ^ 16: + { { 3946174395, 938410993, 2583257939 }, + { 898527522, 1909350615, 938410993 }, + { 1517357015, 2538479259, 1909350615 } }, + // Matrix for nskip = 4 * 8 ^ 16: + { { 3811588895, 3303532086, 2766583698 }, + { 908630605, 2665400165, 3303532086 }, + { 2499994113, 3316180851, 2665400165 } }, + // Matrix for nskip = 5 * 8 ^ 16: + { { 2828295511, 296464469, 3400652741 }, + { 3697213244, 3884416364, 296464469 }, + { 2902099262, 1705355356, 3884416364 } }, + // Matrix for nskip = 6 * 8 ^ 16: + { { 3952581582, 91397022, 1472690314 }, + { 2332659451, 3813545212, 91397022 }, + { 2942299995, 3287843695, 3813545212 } }, + // Matrix for nskip = 7 * 8 ^ 16: + { { 1334460780, 861234488, 2817452481 }, + { 435895955, 3356827989, 861234488 }, + { 1590379239, 2041861019, 3356827989 } }, + // Matrix for nskip = 1 * 8 ^ 17: + { { 4288926968, 3033075037, 1505732852 }, + { 1531633406, 645804125, 3033075037 }, + { 2942690261, 2205365640, 645804125 } }, + // Matrix for nskip = 2 * 8 ^ 17: + { { 3976196483, 3651411522, 1652430357 }, + { 1690405883, 1294990760, 3651411522 }, + { 209339647, 3088484327, 1294990760 } }, + // Matrix for nskip = 3 * 8 ^ 17: + { { 3313281387, 404839765, 4119379625 }, + { 1282760808, 1769786574, 404839765 }, + { 2156822533, 2134509408, 1769786574 } }, + // Matrix for nskip = 4 * 8 ^ 17: + { { 3171589548, 2291131070, 2093793287 }, + { 2997812074, 4093879780, 2291131070 }, + { 3255666800, 858124816, 4093879780 } }, + // Matrix for nskip = 5 * 8 ^ 17: + { { 2671377286, 4060168649, 2412035287 }, + { 2560486338, 828012431, 4060168649 }, + { 431779937, 1288430895, 828012431 } }, + // Matrix for nskip = 6 * 8 ^ 17: + { { 3419357098, 2547678446, 3186955890 }, + { 3335475366, 2875872016, 2547678446 }, + { 1190772134, 216187195, 2875872016 } }, + // Matrix for nskip = 7 * 8 ^ 17: + { { 2462780486, 3788991986, 2965830319 }, + { 4101189674, 1696959105, 3788991986 }, + { 170171245, 376763544, 1696959105 } }, + // Matrix for nskip = 1 * 8 ^ 18: + { { 4113016361, 2999667479, 3995043314 }, + { 1333973326, 4007774239, 2999667479 }, + { 3322921863, 4278103786, 4007774239 } }, + // Matrix for nskip = 2 * 8 ^ 18: + { { 925786347, 2109676036, 1879981040 }, + { 1701566570, 1489702270, 2109676036 }, + { 2719807628, 158549605, 1489702270 } }, + // Matrix for nskip = 3 * 8 ^ 18: + { { 988998360, 4224987734, 2705609303 }, + { 3781735882, 3210618179, 4224987734 }, + { 2000646801, 3763764745, 3210618179 } }, + // Matrix for nskip = 4 * 8 ^ 18: + { { 2255405265, 3460246357, 218033453 }, + { 2135115875, 359516994, 3460246357 }, + { 3568862459, 3114762683, 359516994 } }, + // Matrix for nskip = 5 * 8 ^ 18: + { { 3151385849, 2749420870, 1663192542 }, + { 3858805987, 658557447, 2749420870 }, + { 3895454596, 3780884000, 658557447 } }, + // Matrix for nskip = 6 * 8 ^ 18: + { { 1720065491, 953484022, 1382647120 }, + { 1315666944, 2456296663, 953484022 }, + { 572064418, 2149791939, 2456296663 } }, + // Matrix for nskip = 7 * 8 ^ 18: + { { 2767100879, 4015038188, 1215355080 }, + { 3185998778, 1592475141, 4015038188 }, + { 135551392, 4171059118, 1592475141 } }, + // Matrix for nskip = 1 * 8 ^ 19: + { { 773148471, 4117539411, 3073622315 }, + { 3807175775, 186466108, 4117539411 }, + { 2842197411, 651334129, 186466108 } }, + // Matrix for nskip = 2 * 8 ^ 19: + { { 615242951, 1475251263, 3586439101 }, + { 1693917167, 3058812486, 1475251263 }, + { 568701600, 1164226398, 3058812486 } }, + // Matrix for nskip = 3 * 8 ^ 19: + { { 3729302216, 1041711449, 2647679194 }, + { 3878048889, 135488725, 1041711449 }, + { 508494460, 2178143073, 135488725 } }, + // Matrix for nskip = 4 * 8 ^ 19: + { { 1632636204, 15370275, 2061555515 }, + { 4187505695, 1741164221, 15370275 }, + { 2882176274, 3978412194, 1741164221 } }, + // Matrix for nskip = 5 * 8 ^ 19: + { { 4199667935, 4240821442, 3087593298 }, + { 2968278570, 2185585470, 4240821442 }, + { 2826850420, 371506848, 2185585470 } }, + // Matrix for nskip = 6 * 8 ^ 19: + { { 4002434761, 1455254388, 1267013695 }, + { 2324442395, 2192287989, 1455254388 }, + { 3389390262, 2190852671, 2192287989 } }, + // Matrix for nskip = 7 * 8 ^ 19: + { { 3722528722, 3193070982, 1527096340 }, + { 3155996013, 2278658572, 3193070982 }, + { 2051186788, 4289100465, 2278658572 } }, + // Matrix for nskip = 1 * 8 ^ 20: + { { 3446066703, 344820524, 74213775 }, + { 1008543583, 2579620192, 344820524 }, + { 3753911358, 1538453821, 2579620192 } }, + // Matrix for nskip = 2 * 8 ^ 20: + { { 3600859892, 1269921024, 4069458760 }, + { 2050939727, 2222725697, 1269921024 }, + { 3208347646, 690898125, 2222725697 } }, + // Matrix for nskip = 3 * 8 ^ 20: + { { 2580978896, 2572090525, 3334144098 }, + { 804558063, 250626667, 2572090525 }, + { 843125518, 1038659713, 250626667 } }, + // Matrix for nskip = 4 * 8 ^ 20: + { { 599407451, 2806239788, 1742216102 }, + { 975123999, 764869161, 2806239788 }, + { 2729710367, 1845257036, 764869161 } }, + // Matrix for nskip = 5 * 8 ^ 20: + { { 1900612628, 1237821080, 3847187360 }, + { 4059416755, 2650131939, 1237821080 }, + { 31199658, 2064718263, 2650131939 } }, + // Matrix for nskip = 6 * 8 ^ 20: + { { 1347324880, 3034196764, 3435152676 }, + { 2459581108, 68307108, 3034196764 }, + { 4060225449, 1313975073, 68307108 } }, + // Matrix for nskip = 7 * 8 ^ 20: + { { 832405527, 4273872816, 2483412578 }, + { 1083671641, 2619838177, 4273872816 }, + { 3452165941, 3089879239, 2619838177 } }, + // Matrix for nskip = 1 * 8 ^ 21: + { { 967330218, 3464884028, 3444447102 }, + { 580449578, 1343714307, 3464884028 }, + { 1775329096, 4027221761, 1343714307 } }, + // Matrix for nskip = 2 * 8 ^ 21: + { { 3426136514, 4123590610, 2477690850 }, + { 1284315665, 1604068527, 4123590610 }, + { 1818147893, 320435440, 1604068527 } }, + // Matrix for nskip = 3 * 8 ^ 21: + { { 2183845304, 1753369147, 3320030113 }, + { 1615069375, 2429599106, 1753369147 }, + { 4089942461, 816400070, 2429599106 } }, + // Matrix for nskip = 4 * 8 ^ 21: + { { 2678132557, 89090276, 2719996384 }, + { 607972119, 3383659282, 89090276 }, + { 480221151, 2265789281, 3383659282 } }, + // Matrix for nskip = 5 * 8 ^ 21: + { { 1549131095, 4063932361, 140002783 }, + { 3213919212, 3321129811, 4063932361 }, + { 2806676458, 1803235719, 3321129811 } }, + // Matrix for nskip = 6 * 8 ^ 21: + { { 2289583273, 1236554533, 358687301 }, + { 1498394381, 1159516887, 1236554533 }, + { 359182081, 4214998734, 1159516887 } }, + // Matrix for nskip = 7 * 8 ^ 21: + { { 1434974522, 4046133592, 349947526 }, + { 383007031, 4052481195, 4046133592 }, + { 1677657970, 799675597, 4052481195 } }, + // Matrix for nskip = 1 * 8 ^ 22: + { { 1827237091, 2290099491, 614471834 }, + { 3711385978, 2748163602, 2290099491 }, + { 2067064347, 1071954219, 2748163602 } }, + // Matrix for nskip = 2 * 8 ^ 22: + { { 3894793123, 921712152, 596236860 }, + { 4038673596, 4279784147, 921712152 }, + { 1999065039, 859801225, 4279784147 } }, + // Matrix for nskip = 3 * 8 ^ 22: + { { 3518731582, 2398700699, 3703766159 }, + { 1998914732, 1951351916, 2398700699 }, + { 2852188423, 1461089983, 1951351916 } }, + // Matrix for nskip = 4 * 8 ^ 22: + { { 7276915, 3205297712, 1204204130 }, + { 2667672243, 2737282292, 3205297712 }, + { 2282864144, 2305990443, 2737282292 } }, + // Matrix for nskip = 5 * 8 ^ 22: + { { 2376625824, 3090473348, 776691260 }, + { 4067754877, 2149314284, 3090473348 }, + { 198230411, 2870222545, 2149314284 } }, + // Matrix for nskip = 6 * 8 ^ 22: + { { 1638154181, 688311656, 278971912 }, + { 2626529484, 1769978612, 688311656 }, + { 2951434168, 1794042358, 1769978612 } }, + // Matrix for nskip = 7 * 8 ^ 22: + { { 3742216352, 1164158193, 1223269258 }, + { 3621125172, 3964660872, 1164158193 }, + { 3373873746, 2614176571, 3964660872 } }, + // Matrix for nskip = 1 * 8 ^ 23: + { { 935922304, 2428000499, 510672020 }, + { 1541887892, 92472822, 2428000499 }, + { 4146892220, 1307489118, 92472822 } }, + // Matrix for nskip = 2 * 8 ^ 23: + { { 690398653, 3787391292, 1705516721 }, + { 2953871718, 4173917861, 3787391292 }, + { 817556203, 3090114656, 4173917861 } }, + // Matrix for nskip = 3 * 8 ^ 23: + { { 2596837368, 523638114, 796925063 }, + { 2436421546, 3808361324, 523638114 }, + { 3645860436, 2767640965, 3808361324 } }, + // Matrix for nskip = 4 * 8 ^ 23: + { { 476867729, 1917800003, 1740083735 }, + { 3167988201, 1286715218, 1917800003 }, + { 2579365599, 4173763431, 1286715218 } }, + // Matrix for nskip = 5 * 8 ^ 23: + { { 875985265, 2902381003, 3585549348 }, + { 1487116735, 20494290, 2902381003 }, + { 3417450723, 672893019, 20494290 } }, + // Matrix for nskip = 6 * 8 ^ 23: + { { 680890926, 3782598365, 3927087723 }, + { 3291528625, 2096301120, 3782598365 }, + { 3927430411, 2089751145, 2096301120 } }, + // Matrix for nskip = 7 * 8 ^ 23: + { { 2506371881, 3282095953, 1709670308 }, + { 2778786590, 3316228403, 3282095953 }, + { 3936394935, 4103225131, 3316228403 } }, + // Matrix for nskip = 1 * 8 ^ 24: + { { 4092801160, 3749431174, 542781592 }, + { 1208313783, 217808460, 3749431174 }, + { 2708923752, 348848516, 217808460 } }, + // Matrix for nskip = 2 * 8 ^ 24: + { { 381829350, 1732869179, 3638540651 }, + { 2509789412, 1114357536, 1732869179 }, + { 2465372475, 350550480, 1114357536 } }, + // Matrix for nskip = 3 * 8 ^ 24: + { { 4088394360, 3507668274, 103212933 }, + { 1229010797, 2457049990, 3507668274 }, + { 1543332620, 3557973226, 2457049990 } }, + // Matrix for nskip = 4 * 8 ^ 24: + { { 3712059912, 1698887908, 3706277064 }, + { 2152325130, 232741719, 1698887908 }, + { 4114351745, 170237153, 232741719 } }, + // Matrix for nskip = 5 * 8 ^ 24: + { { 2230538189, 2798697140, 2813869207 }, + { 2098708615, 4247643355, 2798697140 }, + { 1732506223, 3352831267, 4247643355 } }, + // Matrix for nskip = 6 * 8 ^ 24: + { { 141104167, 950363290, 3526146168 }, + { 1842485244, 366288723, 950363290 }, + { 901263071, 3346018419, 366288723 } }, + // Matrix for nskip = 7 * 8 ^ 24: + { { 1273880950, 1252923554, 845609283 }, + { 3523638916, 1756558336, 1252923554 }, + { 983823623, 3396822999, 1756558336 } }, + // Matrix for nskip = 1 * 8 ^ 25: + { { 993804379, 905755330, 1717718779 }, + { 1712994855, 2713148271, 905755330 }, + { 2200585411, 111258429, 2713148271 } }, + // Matrix for nskip = 2 * 8 ^ 25: + { { 82758667, 1871391091, 4127413238 }, + { 3672831523, 69195019, 1871391091 }, + { 3672091415, 3528743235, 69195019 } }, + // Matrix for nskip = 3 * 8 ^ 25: + { { 1954591259, 636118602, 2621269238 }, + { 462961075, 4030630272, 636118602 }, + { 3305976356, 1757343588, 4030630272 } }, + // Matrix for nskip = 4 * 8 ^ 25: + { { 3361372532, 2329303404, 99651939 }, + { 2008671965, 2931758910, 2329303404 }, + { 1113529483, 2374097189, 2931758910 } }, + // Matrix for nskip = 5 * 8 ^ 25: + { { 1475330900, 1973232757, 3087886870 }, + { 1184427939, 3491162930, 1973232757 }, + { 4229179055, 3166281484, 3491162930 } }, + // Matrix for nskip = 6 * 8 ^ 25: + { { 2138712950, 3210181465, 230171794 }, + { 1011789944, 3536018417, 3210181465 }, + { 2847216174, 620673032, 3536018417 } }, + // Matrix for nskip = 7 * 8 ^ 25: + { { 1691375920, 1708800738, 1210582211 }, + { 2919192023, 1561934882, 1708800738 }, + { 3388931282, 2988640653, 1561934882 } }, + // Matrix for nskip = 1 * 8 ^ 26: + { { 1831590873, 1588259595, 1314332382 }, + { 2385989343, 2508077280, 1588259595 }, + { 1787615788, 661437137, 2508077280 } }, + // Matrix for nskip = 2 * 8 ^ 26: + { { 2326052247, 4183591379, 4049009082 }, + { 2604529491, 1453913233, 4183591379 }, + { 2311925423, 1805360390, 1453913233 } }, + // Matrix for nskip = 3 * 8 ^ 26: + { { 664423898, 2590401961, 4225456867 }, + { 3913458720, 1982184590, 2590401961 }, + { 2950459869, 334885555, 1982184590 } }, + // Matrix for nskip = 4 * 8 ^ 26: + { { 3956367490, 604461629, 1257432102 }, + { 794711716, 1155867175, 604461629 }, + { 1777070788, 429445904, 1155867175 } }, + // Matrix for nskip = 5 * 8 ^ 26: + { { 2357556007, 3027793563, 3037152168 }, + { 328118796, 419690250, 3027793563 }, + { 2699357594, 1143766272, 419690250 } }, + // Matrix for nskip = 6 * 8 ^ 26: + { { 3183717084, 2634631308, 2109777894 }, + { 1745049657, 2872637888, 2634631308 }, + { 3660634616, 2434030341, 2872637888 } }, + // Matrix for nskip = 7 * 8 ^ 26: + { { 961674331, 524745427, 3832393053 }, + { 2375268260, 2883640227, 524745427 }, + { 3564327755, 2130782725, 2883640227 } }, + // Matrix for nskip = 1 * 8 ^ 27: + { { 1686241617, 1257046062, 1427609439 }, + { 490376081, 387798431, 1257046062 }, + { 235551485, 1312672615, 387798431 } }, + // Matrix for nskip = 2 * 8 ^ 27: + { { 2362447880, 3445363024, 3160262066 }, + { 2426867845, 4194339866, 3445363024 }, + { 1046144413, 4177893681, 4194339866 } }, + // Matrix for nskip = 3 * 8 ^ 27: + { { 2399569099, 1723951785, 2356709199 }, + { 332901774, 3265509251, 1723951785 }, + { 3616767886, 1726850927, 3265509251 } }, + // Matrix for nskip = 4 * 8 ^ 27: + { { 4251175413, 3559576374, 3107663662 }, + { 697539134, 1909472435, 3559576374 }, + { 280754246, 375835695, 1909472435 } }, + // Matrix for nskip = 5 * 8 ^ 27: + { { 1441163739, 911930333, 4028966669 }, + { 3689446034, 1473406035, 911930333 }, + { 3884376669, 1954838782, 1473406035 } }, + // Matrix for nskip = 6 * 8 ^ 27: + { { 751906018, 4203984455, 2167450892 }, + { 3937403282, 1862670973, 4203984455 }, + { 402523958, 496211406, 1862670973 } }, + // Matrix for nskip = 7 * 8 ^ 27: + { { 726664456, 2233062609, 98234458 }, + { 149028817, 3613797222, 2233062609 }, + { 3848675801, 4164228265, 3613797222 } }, + // Matrix for nskip = 1 * 8 ^ 28: + { { 1099512970, 712404985, 1571467521 }, + { 546519870, 1135109300, 712404985 }, + { 3325312332, 2352874613, 1135109300 } }, + // Matrix for nskip = 2 * 8 ^ 28: + { { 1945425936, 1653045514, 381988982 }, + { 3733376326, 414410025, 1653045514 }, + { 1181583679, 1185848176, 414410025 } }, + // Matrix for nskip = 3 * 8 ^ 28: + { { 80175856, 1301935019, 1963289366 }, + { 3961455404, 65355284, 1301935019 }, + { 3052316027, 2858851708, 65355284 } }, + // Matrix for nskip = 4 * 8 ^ 28: + { { 2526336124, 3019211015, 4215964965 }, + { 2683163472, 4188191530, 3019211015 }, + { 2964651598, 293801056, 4188191530 } }, + // Matrix for nskip = 5 * 8 ^ 28: + { { 1749670132, 1387140872, 762351827 }, + { 2971687592, 1196758134, 1387140872 }, + { 237185264, 1741700121, 1196758134 } }, + // Matrix for nskip = 6 * 8 ^ 28: + { { 4238062407, 481737140, 1487069976 }, + { 878719633, 759707097, 481737140 }, + { 749051338, 825174423, 759707097 } }, + // Matrix for nskip = 7 * 8 ^ 28: + { { 1955913150, 1130524081, 2151646894 }, + { 499306218, 101202533, 1130524081 }, + { 2744191919, 1603656961, 101202533 } }, + // Matrix for nskip = 1 * 8 ^ 29: + { { 1444052678, 2253324417, 39719589 }, + { 1880267534, 2391992038, 2253324417 }, + { 987740265, 3691889508, 2391992038 } }, + // Matrix for nskip = 2 * 8 ^ 29: + { { 166599066, 2335494420, 1232261118 }, + { 2227597731, 2570600780, 2335494420 }, + { 2700034538, 3460843234, 2570600780 } }, + // Matrix for nskip = 3 * 8 ^ 29: + { { 391577970, 1926759295, 2700541692 }, + { 1952364431, 2281246481, 1926759295 }, + { 1819825140, 2377574285, 2281246481 } }, + // Matrix for nskip = 4 * 8 ^ 29: + { { 2511338360, 1188954576, 1251401239 }, + { 2511664974, 292276982, 1188954576 }, + { 697844082, 3093661552, 292276982 } }, + // Matrix for nskip = 5 * 8 ^ 29: + { { 3984792772, 688024800, 2775323178 }, + { 2263182715, 2971941970, 688024800 }, + { 3585402638, 2532257287, 2971941970 } }, + // Matrix for nskip = 6 * 8 ^ 29: + { { 1029044592, 4040666706, 3671213347 }, + { 902332253, 3140636559, 4040666706 }, + { 1429177194, 3213333408, 3140636559 } }, + // Matrix for nskip = 7 * 8 ^ 29: + { { 3635935824, 599310841, 2541542820 }, + { 1942681116, 83716008, 599310841 }, + { 2957115888, 464001685, 83716008 } }, + // Matrix for nskip = 1 * 8 ^ 30: + { { 3624650744, 51993077, 3540268009 }, + { 3252828938, 3710319575, 51993077 }, + { 2858628849, 3910069381, 3710319575 } }, + // Matrix for nskip = 2 * 8 ^ 30: + { { 655966702, 754002362, 1646581402 }, + { 1958331075, 475572423, 754002362 }, + { 3248619000, 3228514800, 475572423 } }, + // Matrix for nskip = 3 * 8 ^ 30: + { { 135820422, 1138672588, 1020827900 }, + { 626151178, 4149545048, 1138672588 }, + { 2180788629, 1314604300, 4149545048 } }, + // Matrix for nskip = 4 * 8 ^ 30: + { { 2760311307, 4166372813, 741596417 }, + { 2282679206, 3090782630, 4166372813 }, + { 3242468721, 1628442374, 3090782630 } }, + // Matrix for nskip = 5 * 8 ^ 30: + { { 88347075, 1420161828, 3113798953 }, + { 217224032, 2004343529, 1420161828 }, + { 4048389654, 3845790311, 2004343529 } }, + // Matrix for nskip = 6 * 8 ^ 30: + { { 4237022985, 912148655, 165387559 }, + { 252556101, 230998942, 912148655 }, + { 2978268820, 7678432, 230998942 } }, + // Matrix for nskip = 7 * 8 ^ 30: + { { 1702648282, 936444437, 2113813328 }, + { 2870633999, 384435053, 936444437 }, + { 2426580506, 1660785110, 384435053 } }, + // Matrix for nskip = 1 * 8 ^ 31: + { { 4265279407, 3532111852, 1754687396 }, + { 500404765, 2603727025, 3532111852 }, + { 1428367254, 3149485478, 2603727025 } }, + // Matrix for nskip = 2 * 8 ^ 31: + { { 2873769531, 2081104178, 596284397 }, + { 4153800443, 1261269623, 2081104178 }, + { 3967600061, 1830023157, 1261269623 } }, + // Matrix for nskip = 3 * 8 ^ 31: + { { 1219416476, 2833805942, 877956083 }, + { 4136201738, 926561185, 2833805942 }, + { 790563916, 2950279312, 926561185 } }, + // Matrix for nskip = 4 * 8 ^ 31: + { { 278611533, 2229285304, 3443204327 }, + { 3110641420, 77498444, 2229285304 }, + { 3904070810, 1070507239, 77498444 } }, + // Matrix for nskip = 5 * 8 ^ 31: + { { 1569490059, 1438273012, 1676406913 }, + { 2246148877, 835628171, 1438273012 }, + { 1001911068, 165198836, 835628171 } }, + // Matrix for nskip = 6 * 8 ^ 31: + { { 219341062, 236464123, 3922106376 }, + { 244990374, 2122146632, 236464123 }, + { 2065383788, 2977102789, 2122146632 } }, + // Matrix for nskip = 7 * 8 ^ 31: + { { 2250560481, 1729521343, 424414765 }, + { 2059608998, 3276353542, 1729521343 }, + { 2230558099, 3933677451, 3276353542 } }, + // Matrix for nskip = 1 * 8 ^ 32: + { { 544639534, 568528663, 2177189807 }, + { 2475829068, 121482268, 568528663 }, + { 876978915, 3116647617, 121482268 } }, + // Matrix for nskip = 2 * 8 ^ 32: + { { 1547862823, 2404658587, 4191448009 }, + { 2158188804, 2976916793, 2404658587 }, + { 168571747, 1691884706, 2976916793 } }, + // Matrix for nskip = 3 * 8 ^ 32: + { { 2707010111, 2933510859, 4240166566 }, + { 1177241360, 62338927, 2933510859 }, + { 2798158767, 906126073, 62338927 } }, + // Matrix for nskip = 4 * 8 ^ 32: + { { 3208213311, 4212638780, 3235157352 }, + { 671148556, 2951207765, 4212638780 }, + { 2075145516, 2395485231, 2951207765 } }, + // Matrix for nskip = 5 * 8 ^ 32: + { { 3757387996, 3349220842, 3722506196 }, + { 224784515, 2952700002, 3349220842 }, + { 1142378033, 2302905244, 2952700002 } }, + // Matrix for nskip = 6 * 8 ^ 32: + { { 1941283113, 145407649, 659394903 }, + { 347432419, 1571592397, 145407649 }, + { 2204145504, 3369375773, 1571592397 } }, + // Matrix for nskip = 7 * 8 ^ 32: + { { 1094854803, 386906095, 3767619826 }, + { 1281474767, 179198568, 386906095 }, + { 3021644798, 3594781674, 179198568 } }, + // Matrix for nskip = 1 * 8 ^ 33: + { { 4080517315, 2133433101, 4043998180 }, + { 2044221845, 867670560, 2133433101 }, + { 834432416, 3613001199, 867670560 } }, + // Matrix for nskip = 2 * 8 ^ 33: + { { 4102885735, 1319434267, 2678775073 }, + { 740092580, 607380970, 1319434267 }, + { 2198271844, 2610193258, 607380970 } }, + // Matrix for nskip = 3 * 8 ^ 33: + { { 2725610481, 764583647, 1059048169 }, + { 2571438051, 3510614410, 764583647 }, + { 1753866259, 3525435230, 3510614410 } }, + // Matrix for nskip = 4 * 8 ^ 33: + { { 1165218048, 1317690360, 1189150958 }, + { 399240205, 2507168618, 1317690360 }, + { 2988334517, 2687593413, 2507168618 } }, + // Matrix for nskip = 5 * 8 ^ 33: + { { 1160307294, 3843003921, 120011318 }, + { 1648569394, 2331840681, 3843003921 }, + { 2666551617, 1826785014, 2331840681 } }, + // Matrix for nskip = 6 * 8 ^ 33: + { { 2745374441, 3528536028, 2077936780 }, + { 3475527779, 16047360, 3528536028 }, + { 1346223401, 3691116188, 16047360 } }, + // Matrix for nskip = 7 * 8 ^ 33: + { { 3985894561, 4225395152, 3428831071 }, + { 3666024757, 3230532631, 4225395152 }, + { 2407932196, 4261187489, 3230532631 } }, + // Matrix for nskip = 1 * 8 ^ 34: + { { 1028861702, 4082006648, 338232527 }, + { 1888486946, 1842080991, 4082006648 }, + { 3903826366, 3109935091, 1842080991 } }, + // Matrix for nskip = 2 * 8 ^ 34: + { { 614134826, 2261996505, 2888080641 }, + { 710199359, 2773979788, 2261996505 }, + { 1144301620, 2554371815, 2773979788 } }, + // Matrix for nskip = 3 * 8 ^ 34: + { { 3872045348, 2988495416, 3084935324 }, + { 1788745968, 3505214566, 2988495416 }, + { 2741627244, 478558438, 3505214566 } }, + // Matrix for nskip = 4 * 8 ^ 34: + { { 4056173823, 1285620078, 357420018 }, + { 2423072612, 2309408315, 1285620078 }, + { 1533175115, 2760088020, 2309408315 } }, + // Matrix for nskip = 5 * 8 ^ 34: + { { 3469546091, 369086126, 3478496559 }, + { 3780710118, 589042104, 369086126 }, + { 1900191562, 3935275606, 589042104 } }, + // Matrix for nskip = 6 * 8 ^ 34: + { { 1682769046, 1059146837, 2627186100 }, + { 975501718, 2081627761, 1059146837 }, + { 4182902400, 2809990303, 2081627761 } }, + // Matrix for nskip = 7 * 8 ^ 34: + { { 3037332387, 2654288975, 181147870 }, + { 454223518, 808123674, 2654288975 }, + { 967475810, 1382885174, 808123674 } }, + // Matrix for nskip = 1 * 8 ^ 35: + { { 4264130267, 815015434, 3142242173 }, + { 180649975, 2500813569, 815015434 }, + { 3378723563, 829683767, 2500813569 } }, + // Matrix for nskip = 2 * 8 ^ 35: + { { 4174387531, 1030729435, 2812778314 }, + { 1752988797, 4044178729, 1030729435 }, + { 467969301, 554748104, 4044178729 } }, + // Matrix for nskip = 3 * 8 ^ 35: + { { 1224655671, 538480994, 911775489 }, + { 571730491, 1197428336, 538480994 }, + { 310254483, 3482088360, 1197428336 } }, + // Matrix for nskip = 4 * 8 ^ 35: + { { 1348429235, 2928743274, 3776082629 }, + { 3607529209, 3069812185, 2928743274 }, + { 2542432347, 3208181168, 3069812185 } }, + // Matrix for nskip = 5 * 8 ^ 35: + { { 2414375640, 2994139106, 1829200407 }, + { 3723068499, 3276234188, 2994139106 }, + { 1384068579, 3863982741, 3276234188 } }, + // Matrix for nskip = 6 * 8 ^ 35: + { { 798763723, 2897556757, 3145856482 }, + { 3421663444, 3946110585, 2897556757 }, + { 1853745554, 260368160, 3946110585 } }, + // Matrix for nskip = 7 * 8 ^ 35: + { { 95178102, 3740645591, 3060595950 }, + { 3321952562, 3932965485, 3740645591 }, + { 76660843, 2044406932, 3932965485 } }, + // Matrix for nskip = 1 * 8 ^ 36: + { { 4064845753, 668285756, 3816217625 }, + { 3713143233, 1380634204, 668285756 }, + { 3533700508, 1192551435, 1380634204 } }, + // Matrix for nskip = 2 * 8 ^ 36: + { { 1515684518, 1706771705, 728123349 }, + { 3174850469, 2057456462, 1706771705 }, + { 3410402985, 2897339640, 2057456462 } }, + // Matrix for nskip = 3 * 8 ^ 36: + { { 493252920, 4038063126, 2168451262 }, + { 363246278, 1249105026, 4038063126 }, + { 3395543717, 3358422070, 1249105026 } }, + // Matrix for nskip = 4 * 8 ^ 36: + { { 3082272717, 531091457, 1390161328 }, + { 3895139973, 2171402857, 531091457 }, + { 4030688141, 3049703400, 2171402857 } }, + // Matrix for nskip = 5 * 8 ^ 36: + { { 3935740675, 2355871533, 3949682718 }, + { 2931048320, 902295474, 2355871533 }, + { 847382876, 591758943, 902295474 } }, + // Matrix for nskip = 6 * 8 ^ 36: + { { 1096633558, 956915353, 71119600 }, + { 1282074175, 3814732591, 956915353 }, + { 1834617826, 3605659623, 3814732591 } }, + // Matrix for nskip = 7 * 8 ^ 36: + { { 1213485394, 883705085, 1819500595 }, + { 3547515338, 2658882772, 883705085 }, + { 3298597677, 2195730734, 2658882772 } }, + // Matrix for nskip = 1 * 8 ^ 37: + { { 1241147206, 3193892819, 1244284192 }, + { 65180262, 4065669017, 3193892819 }, + { 1484817937, 3661081858, 4065669017 } }, + // Matrix for nskip = 2 * 8 ^ 37: + { { 1438760812, 3491341751, 3414470157 }, + { 2805337292, 272266053, 3491341751 }, + { 824109230, 3202556526, 272266053 } }, + // Matrix for nskip = 3 * 8 ^ 37: + { { 3548908153, 1458259435, 2902555273 }, + { 3865796034, 2523447078, 1458259435 }, + { 2359984375, 3898395136, 2523447078 } }, + // Matrix for nskip = 4 * 8 ^ 37: + { { 135412706, 3627115412, 2345042216 }, + { 1565169824, 2166856449, 3627115412 }, + { 1026946745, 3467845248, 2166856449 } }, + // Matrix for nskip = 5 * 8 ^ 37: + { { 4146693931, 4048659004, 2768049120 }, + { 2555866488, 2548281288, 4048659004 }, + { 2954738533, 4242463239, 2548281288 } }, + // Matrix for nskip = 6 * 8 ^ 37: + { { 1796100563, 2291501743, 3432007410 }, + { 1204345078, 1110795947, 2291501743 }, + { 3388382946, 3937816720, 1110795947 } }, + // Matrix for nskip = 7 * 8 ^ 37: + { { 3208221515, 607811602, 223757102 }, + { 377063363, 3323143974, 607811602 }, + { 279359428, 3272907713, 3323143974 } }, + // Matrix for nskip = 1 * 8 ^ 38: + { { 1889419951, 3256876154, 1240505488 }, + { 1254783743, 989966800, 3256876154 }, + { 1995297400, 3692472918, 989966800 } }, + // Matrix for nskip = 2 * 8 ^ 38: + { { 3206226875, 285700890, 496017472 }, + { 2515316194, 2129675196, 285700890 }, + { 1863853990, 2673457552, 2129675196 } }, + // Matrix for nskip = 3 * 8 ^ 38: + { { 2643396669, 1141176790, 2183048631 }, + { 2796763418, 686457718, 1141176790 }, + { 3473541724, 755015447, 686457718 } }, + // Matrix for nskip = 4 * 8 ^ 38: + { { 4163770641, 255160418, 772100749 }, + { 1987092456, 3237660221, 255160418 }, + { 1394381051, 4216039401, 3237660221 } }, + // Matrix for nskip = 5 * 8 ^ 38: + { { 2744038617, 4151599085, 1086739611 }, + { 2137012024, 1231067556, 4151599085 }, + { 2054217062, 1474724988, 1231067556 } }, + // Matrix for nskip = 6 * 8 ^ 38: + { { 1966926556, 2167105562, 3642406633 }, + { 3575908026, 76072334, 2167105562 }, + { 438275780, 1024705325, 76072334 } }, + // Matrix for nskip = 7 * 8 ^ 38: + { { 3144149631, 1078973412, 1395133864 }, + { 1200101371, 2263842276, 1078973412 }, + { 1990245354, 4126971783, 2263842276 } }, + // Matrix for nskip = 1 * 8 ^ 39: + { { 2133915627, 2713747584, 627765421 }, + { 2300605925, 35690583, 2713747584 }, + { 2918902946, 2638220304, 35690583 } }, + // Matrix for nskip = 2 * 8 ^ 39: + { { 2587549655, 998684270, 4292130625 }, + { 1791772791, 2820705344, 998684270 }, + { 124590158, 3831143549, 2820705344 } }, + // Matrix for nskip = 3 * 8 ^ 39: + { { 3910080826, 1802646553, 3446926966 }, + { 129865302, 1755670478, 1802646553 }, + { 1006007080, 2257707516, 1755670478 } }, + // Matrix for nskip = 4 * 8 ^ 39: + { { 978482299, 3200877282, 497605289 }, + { 3717741518, 3737164414, 3200877282 }, + { 4046686626, 861393946, 3737164414 } }, + // Matrix for nskip = 5 * 8 ^ 39: + { { 3183253558, 201453184, 3145469059 }, + { 3983740037, 3717279042, 201453184 }, + { 976459397, 485566112, 3717279042 } }, + // Matrix for nskip = 6 * 8 ^ 39: + { { 1649247358, 1293997566, 1141681757 }, + { 2104529013, 3994478979, 1293997566 }, + { 12048398, 1296267255, 3994478979 } }, + // Matrix for nskip = 7 * 8 ^ 39: + { { 1277127010, 3409985649, 2357026796 }, + { 546146378, 1239287374, 3409985649 }, + { 684416427, 1435662521, 1239287374 } }, + // Matrix for nskip = 1 * 8 ^ 40: + { { 2665561897, 300934584, 3179822945 }, + { 893043137, 2031413512, 300934584 }, + { 3806926970, 2413249929, 2031413512 } }, + // Matrix for nskip = 2 * 8 ^ 40: + { { 1417581911, 3071835354, 2575196237 }, + { 4101127251, 1375339216, 3071835354 }, + { 847617977, 3632503316, 1375339216 } }, + // Matrix for nskip = 3 * 8 ^ 40: + { { 608673033, 22126256, 3556899267 }, + { 1727979207, 849327659, 22126256 }, + { 1702248031, 791369590, 849327659 } }, + // Matrix for nskip = 4 * 8 ^ 40: + { { 2747488994, 3296604805, 898095468 }, + { 1742777145, 219265369, 3296604805 }, + { 823714885, 667779292, 219265369 } }, + // Matrix for nskip = 5 * 8 ^ 40: + { { 2021014596, 471433423, 2651735970 }, + { 585977516, 1605468910, 471433423 }, + { 549943099, 3890474462, 1605468910 } }, + // Matrix for nskip = 6 * 8 ^ 40: + { { 3574350911, 1933183379, 2250823873 }, + { 1024311233, 365568357, 1933183379 }, + { 3430128519, 3029426194, 365568357 } }, + // Matrix for nskip = 7 * 8 ^ 40: + { { 1074178830, 2265105869, 2758013402 }, + { 4125786414, 1034741107, 2265105869 }, + { 1441524697, 2229554511, 1034741107 } }, + // Matrix for nskip = 1 * 8 ^ 41: + { { 2640209692, 3040506537, 3626115220 }, + { 161827078, 852668118, 3040506537 }, + { 3856381322, 3360242076, 852668118 } }, + // Matrix for nskip = 2 * 8 ^ 41: + { { 3734246393, 4151553160, 4177051283 }, + { 266522866, 1731798531, 4151553160 }, + { 632196679, 3864297722, 1731798531 } }, + // Matrix for nskip = 3 * 8 ^ 41: + { { 688933188, 355423319, 287306155 }, + { 1805598431, 3402169658, 355423319 }, + { 2000267685, 2145558314, 3402169658 } }, + // Matrix for nskip = 4 * 8 ^ 41: + { { 1694175127, 1087914338, 2384195794 }, + { 2764925057, 505782858, 1087914338 }, + { 3235634082, 807915248, 505782858 } }, + // Matrix for nskip = 5 * 8 ^ 41: + { { 993693315, 3946332366, 3916271739 }, + { 1789813323, 4018933334, 3946332366 }, + { 441058505, 3553235314, 4018933334 } }, + // Matrix for nskip = 6 * 8 ^ 41: + { { 1144818794, 3134263190, 1846865568 }, + { 1502689349, 1628360471, 3134263190 }, + { 745146577, 1872576407, 1628360471 } }, + // Matrix for nskip = 7 * 8 ^ 41: + { { 3398717147, 3990568019, 892329010 }, + { 3847547913, 3198332877, 3990568019 }, + { 333749571, 1549630885, 3198332877 } }, + // Matrix for nskip = 1 * 8 ^ 42: + { { 2402749950, 2353776151, 75909174 }, + { 890570951, 1752665661, 2353776151 }, + { 3120241607, 3862435696, 1752665661 } }, + // Matrix for nskip = 2 * 8 ^ 42: + { { 2427906178, 3580155704, 949770784 }, + { 226153695, 1230515664, 3580155704 }, + { 1988835001, 986791581, 1230515664 } }, + // Matrix for nskip = 3 * 8 ^ 42: + { { 2162922488, 4037183513, 346268022 }, + { 2752767565, 2852643415, 4037183513 }, + { 3557895539, 3796282786, 2852643415 } }, + // Matrix for nskip = 4 * 8 ^ 42: + { { 1774047142, 3199155377, 3106427820 }, + { 1901920839, 4290900039, 3199155377 }, + { 4178980191, 280623348, 4290900039 } }, + // Matrix for nskip = 5 * 8 ^ 42: + { { 564504637, 3960126556, 13271050 }, + { 3975695622, 272607318, 3960126556 }, + { 1199282733, 981722530, 272607318 } }, + // Matrix for nskip = 6 * 8 ^ 42: + { { 3723690896, 3153461912, 693938118 }, + { 2676196226, 1636264737, 3153461912 }, + { 764380249, 3364804206, 1636264737 } }, + // Matrix for nskip = 7 * 8 ^ 42: + { { 2002746065, 838117661, 347920205 }, + { 3311479485, 2381255152, 838117661 }, + { 4107898714, 2782779087, 2381255152 } }, + // Matrix for nskip = 1 * 8 ^ 43: + { { 3567524348, 1934119675, 3188270128 }, + { 2997767678, 826363896, 1934119675 }, + { 262952343, 614326610, 826363896 } }, + // Matrix for nskip = 2 * 8 ^ 43: + { { 1625613062, 4288164505, 2481284279 }, + { 4273461426, 1177260757, 4288164505 }, + { 305959988, 4017252267, 1177260757 } }, + // Matrix for nskip = 3 * 8 ^ 43: + { { 3536417809, 429648601, 2955466274 }, + { 1272075175, 3057838997, 429648601 }, + { 2269698346, 4011682346, 3057838997 } }, + // Matrix for nskip = 4 * 8 ^ 43: + { { 337929267, 333342539, 418300166 }, + { 2944208672, 379097734, 333342539 }, + { 2084056909, 3625475947, 379097734 } }, + // Matrix for nskip = 5 * 8 ^ 43: + { { 68058625, 1918117806, 635887182 }, + { 1946098288, 2963456150, 1918117806 }, + { 2625600235, 2337231210, 2963456150 } }, + // Matrix for nskip = 6 * 8 ^ 43: + { { 1700493457, 3627573759, 545164662 }, + { 1921927973, 1170497671, 3627573759 }, + { 3094336698, 2906222607, 1170497671 } }, + // Matrix for nskip = 7 * 8 ^ 43: + { { 575329368, 1216196496, 4089812320 }, + { 2113496301, 1220844336, 1216196496 }, + { 3926254763, 817590918, 1220844336 } }, + // Matrix for nskip = 1 * 8 ^ 44: + { { 1189899255, 1307754719, 1214919992 }, + { 3736721708, 3514751918, 1307754719 }, + { 732435953, 2021244538, 3514751918 } }, + // Matrix for nskip = 2 * 8 ^ 44: + { { 4089172695, 1533534334, 525643282 }, + { 1497577018, 1335684482, 1533534334 }, + { 2079007086, 3977541427, 1335684482 } }, + // Matrix for nskip = 3 * 8 ^ 44: + { { 851614119, 2992100005, 2852461785 }, + { 2850360626, 2514447281, 2992100005 }, + { 978015612, 1397973230, 2514447281 } }, + // Matrix for nskip = 4 * 8 ^ 44: + { { 3075256652, 2762754934, 3846844247 }, + { 3057872364, 3274545167, 2762754934 }, + { 4028573983, 938934351, 3274545167 } }, + // Matrix for nskip = 5 * 8 ^ 44: + { { 1356476668, 2626409409, 1479462144 }, + { 1188404397, 1260428167, 2626409409 }, + { 3595448064, 2360949430, 1260428167 } }, + // Matrix for nskip = 6 * 8 ^ 44: + { { 1027674032, 887967109, 3655047107 }, + { 3381172536, 2839247420, 887967109 }, + { 1109942153, 1231881661, 2839247420 } }, + // Matrix for nskip = 7 * 8 ^ 44: + { { 3084422684, 3716427472, 3899800153 }, + { 2713114448, 2433847057, 3716427472 }, + { 2089286798, 4032596403, 2433847057 } }, + // Matrix for nskip = 1 * 8 ^ 45: + { { 2597859300, 2880151048, 2523330453 }, + { 1121709186, 175667448, 2880151048 }, + { 4182510911, 1723133625, 175667448 } }, + // Matrix for nskip = 2 * 8 ^ 45: + { { 484148868, 1404283933, 2982534313 }, + { 3736767353, 3179865161, 1404283933 }, + { 391120388, 3758716888, 3179865161 } }, + // Matrix for nskip = 3 * 8 ^ 45: + { { 3773686289, 1118146915, 4257811308 }, + { 2626215981, 2155767823, 1118146915 }, + { 4216113535, 234812272, 2155767823 } }, + // Matrix for nskip = 4 * 8 ^ 45: + { { 2138867468, 1128973399, 2133702321 }, + { 1613561693, 3622350766, 1128973399 }, + { 1500151924, 3759983985, 3622350766 } }, + // Matrix for nskip = 5 * 8 ^ 45: + { { 2098219600, 3500149955, 509598935 }, + { 3938592198, 2627573355, 3500149955 }, + { 2296762399, 2144538279, 2627573355 } }, + // Matrix for nskip = 6 * 8 ^ 45: + { { 1272813809, 709982328, 2430723917 }, + { 3808746634, 1052744045, 709982328 }, + { 346250782, 2541155134, 1052744045 } }, + // Matrix for nskip = 7 * 8 ^ 45: + { { 959495863, 240812937, 1778012651 }, + { 803153186, 1920219267, 240812937 }, + { 2528085623, 422007, 1920219267 } }, + // Matrix for nskip = 1 * 8 ^ 46: + { { 3027706760, 3786576552, 2698781808 }, + { 2810527099, 90498489, 3786576552 }, + { 4220122612, 1855245979, 90498489 } }, + // Matrix for nskip = 2 * 8 ^ 46: + { { 3739389517, 1110440720, 917457922 }, + { 2163873618, 3707591763, 1110440720 }, + { 2667061910, 2533383962, 3707591763 } }, + // Matrix for nskip = 3 * 8 ^ 46: + { { 3440567542, 213023128, 821316937 }, + { 1289665822, 1120982854, 213023128 }, + { 1107018173, 2157902557, 1120982854 } }, + // Matrix for nskip = 4 * 8 ^ 46: + { { 1545226000, 1812182123, 3693349190 }, + { 3422065122, 3291428549, 1812182123 }, + { 1193168720, 2072837757, 3291428549 } }, + // Matrix for nskip = 5 * 8 ^ 46: + { { 1411838727, 1497286518, 2743320941 }, + { 1476608684, 3759942398, 1497286518 }, + { 3033567880, 1132137328, 3759942398 } }, + // Matrix for nskip = 6 * 8 ^ 46: + { { 4164586694, 3847046376, 939466538 }, + { 455920568, 1287777429, 3847046376 }, + { 2394981758, 891603161, 1287777429 } }, + // Matrix for nskip = 7 * 8 ^ 46: + { { 3992667160, 390631011, 4070853162 }, + { 1146538952, 1264300453, 390631011 }, + { 2489808111, 407533173, 1264300453 } }, + // Matrix for nskip = 1 * 8 ^ 47: + { { 3230096243, 2131723358, 3262178024 }, + { 2882890127, 4088518247, 2131723358 }, + { 3991553306, 1282224087, 4088518247 } }, + // Matrix for nskip = 2 * 8 ^ 47: + { { 301207261, 1722796810, 3697719854 }, + { 3350228505, 3410986694, 1722796810 }, + { 3684514720, 2846958957, 3410986694 } }, + // Matrix for nskip = 3 * 8 ^ 47: + { { 3625524738, 3319692776, 3795749903 }, + { 1715640681, 1890913372, 3319692776 }, + { 225727143, 928307593, 1890913372 } }, + // Matrix for nskip = 4 * 8 ^ 47: + { { 1532963114, 4236235786, 3871128158 }, + { 3540401964, 1285250577, 4236235786 }, + { 1105070646, 2764245175, 1285250577 } }, + // Matrix for nskip = 5 * 8 ^ 47: + { { 3740934706, 2563937648, 2746910512 }, + { 3298575982, 2047742419, 2563937648 }, + { 654443081, 2109897740, 2047742419 } }, + // Matrix for nskip = 6 * 8 ^ 47: + { { 1240524792, 1728254085, 119873755 }, + { 1505600996, 2604901554, 1728254085 }, + { 3134968130, 2798059827, 2604901554 } }, + // Matrix for nskip = 7 * 8 ^ 47: + { { 1468859634, 1067606885, 482418964 }, + { 2025997689, 632183943, 1067606885 }, + { 152578308, 2630662559, 632183943 } }, + // Matrix for nskip = 1 * 8 ^ 48: + { { 210906218, 3068599594, 3034582784 }, + { 340633153, 4004365908, 3068599594 }, + { 4238928187, 2299166464, 4004365908 } }, + // Matrix for nskip = 2 * 8 ^ 48: + { { 2274701639, 3955606166, 3081246407 }, + { 3199954992, 3948054919, 3955606166 }, + { 2399101442, 3438340286, 3948054919 } }, + // Matrix for nskip = 3 * 8 ^ 48: + { { 1699759143, 4037535932, 1219209632 }, + { 633837171, 3333667032, 4037535932 }, + { 1309772249, 2404397407, 3333667032 } }, + // Matrix for nskip = 4 * 8 ^ 48: + { { 504137100, 1182303684, 201533985 }, + { 4188299661, 3042453580, 1182303684 }, + { 2578519273, 2674782930, 3042453580 } }, + // Matrix for nskip = 5 * 8 ^ 48: + { { 592752793, 2717374630, 1743344011 }, + { 1375705778, 3320840707, 2717374630 }, + { 128640966, 3026546742, 3320840707 } }, + // Matrix for nskip = 6 * 8 ^ 48: + { { 1370637124, 3074764013, 228550476 }, + { 1199760826, 3450980261, 3074764013 }, + { 1618563336, 1054833852, 3450980261 } }, + // Matrix for nskip = 7 * 8 ^ 48: + { { 1611431067, 3710031515, 2854732050 }, + { 528870942, 2907234375, 3710031515 }, + { 3445439485, 1092238667, 2907234375 } }, + // Matrix for nskip = 1 * 8 ^ 49: + { { 1382964588, 2578452047, 3140440866 }, + { 261861891, 1076783073, 2578452047 }, + { 1634588989, 164438428, 1076783073 } }, + // Matrix for nskip = 2 * 8 ^ 49: + { { 2529186343, 526867394, 3102803247 }, + { 2687252475, 2908898908, 526867394 }, + { 1213100579, 86050422, 2908898908 } }, + // Matrix for nskip = 3 * 8 ^ 49: + { { 1961703304, 2865880716, 3245956893 }, + { 2618763101, 2785604515, 2865880716 }, + { 2898900229, 1099125661, 2785604515 } }, + // Matrix for nskip = 4 * 8 ^ 49: + { { 2690118316, 538108523, 790337895 }, + { 4193870709, 1053552056, 538108523 }, + { 1635227281, 4002399925, 1053552056 } }, + // Matrix for nskip = 5 * 8 ^ 49: + { { 746488794, 2143647216, 1919679021 }, + { 3920176380, 1994557046, 2143647216 }, + { 661950432, 921383941, 1994557046 } }, + // Matrix for nskip = 6 * 8 ^ 49: + { { 1934635577, 2678342194, 4048456688 }, + { 3769235275, 3122368790, 2678342194 }, + { 3794884445, 2578750044, 3122368790 } }, + // Matrix for nskip = 7 * 8 ^ 49: + { { 2345462407, 3273239577, 504673677 }, + { 2663769112, 483235505, 3273239577 }, + { 2863427199, 2990731351, 483235505 } }, + // Matrix for nskip = 1 * 8 ^ 50: + { { 2123712957, 4205383007, 1812304090 }, + { 1095349745, 166243972, 4205383007 }, + { 428569070, 2128782357, 166243972 } }, + // Matrix for nskip = 2 * 8 ^ 50: + { { 1330151766, 3569679412, 4107175982 }, + { 3808641551, 3621125056, 3569679412 }, + { 4262164578, 1927692878, 3621125056 } }, + // Matrix for nskip = 3 * 8 ^ 50: + { { 4091558631, 3732834681, 466628750 }, + { 297727134, 2456485740, 3732834681 }, + { 1818617085, 834096815, 2456485740 } }, + // Matrix for nskip = 4 * 8 ^ 50: + { { 3606295184, 2442739556, 3894922338 }, + { 1629626641, 2729678535, 2442739556 }, + { 3379124758, 4279360935, 2729678535 } }, + // Matrix for nskip = 5 * 8 ^ 50: + { { 3518339108, 1807718360, 3760359041 }, + { 3698267057, 3466970024, 1807718360 }, + { 3728530930, 3457548085, 3466970024 } }, + // Matrix for nskip = 6 * 8 ^ 50: + { { 2193444679, 408556626, 3012130337 }, + { 1097569863, 59894341, 408556626 }, + { 3860432799, 476070138, 59894341 } }, + // Matrix for nskip = 7 * 8 ^ 50: + { { 1063004122, 547821813, 3531749039 }, + { 3513263202, 1281130561, 547821813 }, + { 3768689719, 180869393, 1281130561 } }, + // Matrix for nskip = 1 * 8 ^ 51: + { { 1052092278, 4249024666, 919210106 }, + { 3253349463, 3629539480, 4249024666 }, + { 852514024, 4025926501, 3629539480 } }, + // Matrix for nskip = 2 * 8 ^ 51: + { { 12394571, 1252747620, 2133571953 }, + { 4227339509, 3197545170, 1252747620 }, + { 1884529704, 1976203831, 3197545170 } }, + // Matrix for nskip = 3 * 8 ^ 51: + { { 2331594780, 452832640, 1101195955 }, + { 2939334015, 2029416251, 452832640 }, + { 1096100666, 3366782607, 2029416251 } }, + // Matrix for nskip = 4 * 8 ^ 51: + { { 2986331025, 2671019282, 2847338542 }, + { 3173738401, 3542657885, 2671019282 }, + { 745203060, 1546667401, 3542657885 } }, + // Matrix for nskip = 5 * 8 ^ 51: + { { 3475245690, 1308019352, 1824121179 }, + { 2721990050, 584665331, 1308019352 }, + { 935407479, 3072929538, 584665331 } }, + // Matrix for nskip = 6 * 8 ^ 51: + { { 1254243785, 987948282, 836901607 }, + { 2154496016, 3293370693, 987948282 }, + { 2487351160, 2120370930, 3293370693 } }, + // Matrix for nskip = 7 * 8 ^ 51: + { { 614238014, 976296831, 2444588607 }, + { 3245218993, 99887253, 976296831 }, + { 4012293175, 407199536, 99887253 } }, + // Matrix for nskip = 1 * 8 ^ 52: + { { 2613012997, 2311336951, 2911336433 }, + { 1493974713, 92565032, 2311336951 }, + { 2786645250, 257065974, 92565032 } }, + // Matrix for nskip = 2 * 8 ^ 52: + { { 3424925004, 2776053372, 2204068573 }, + { 3770626858, 2509257810, 2776053372 }, + { 2979919489, 1146336783, 2509257810 } }, + // Matrix for nskip = 3 * 8 ^ 52: + { { 2499905758, 2215361770, 3750482090 }, + { 1105380130, 3511408930, 2215361770 }, + { 634471839, 2666607166, 3511408930 } }, + // Matrix for nskip = 4 * 8 ^ 52: + { { 1474384834, 827894421, 515339473 }, + { 1373055755, 1949809417, 827894421 }, + { 3088339524, 1194193824, 1949809417 } }, + // Matrix for nskip = 5 * 8 ^ 52: + { { 811682426, 1464831324, 673124742 }, + { 1737209131, 4147063048, 1464831324 }, + { 104747063, 352467977, 4147063048 } }, + // Matrix for nskip = 6 * 8 ^ 52: + { { 1759193844, 2367252271, 658497461 }, + { 2079352492, 183217259, 2367252271 }, + { 4048695575, 533708602, 183217259 } }, + // Matrix for nskip = 7 * 8 ^ 52: + { { 2604083920, 2202319015, 2821035593 }, + { 3199388318, 109366125, 2202319015 }, + { 552179285, 3360277248, 109366125 } }, + // Matrix for nskip = 1 * 8 ^ 53: + { { 1825805135, 1289872272, 3700877161 }, + { 3433422861, 4062509844, 1289872272 }, + { 3019008744, 2060641859, 4062509844 } }, + // Matrix for nskip = 2 * 8 ^ 53: + { { 3842597153, 4253338264, 3424495942 }, + { 698444416, 60268595, 4253338264 }, + { 4096010585, 47309624, 60268595 } }, + // Matrix for nskip = 3 * 8 ^ 53: + { { 496690861, 2839992631, 523849894 }, + { 3748568076, 1725353677, 2839992631 }, + { 1590121940, 1652142356, 1725353677 } }, + // Matrix for nskip = 4 * 8 ^ 53: + { { 2662288323, 2043518992, 1593435980 }, + { 1330201507, 3618850300, 2043518992 }, + { 2538793204, 271787962, 3618850300 } }, + // Matrix for nskip = 5 * 8 ^ 53: + { { 3290637626, 1877437091, 683414954 }, + { 297749, 1492496540, 1877437091 }, + { 2568049682, 3340892636, 1492496540 } }, + // Matrix for nskip = 6 * 8 ^ 53: + { { 1177494705, 170978053, 1258089776 }, + { 175903832, 2352110692, 170978053 }, + { 3367780341, 265547447, 2352110692 } }, + // Matrix for nskip = 7 * 8 ^ 53: + { { 4000259518, 1585853138, 1894954679 }, + { 4025122327, 1695479283, 1585853138 }, + { 2854628986, 489784443, 1695479283 } }, + // Matrix for nskip = 1 * 8 ^ 54: + { { 741020448, 997594656, 2398808739 }, + { 1160477043, 1522130854, 997594656 }, + { 3036916315, 2847712653, 1522130854 } }, + // Matrix for nskip = 2 * 8 ^ 54: + { { 2654964886, 1889728930, 53329096 }, + { 2042322941, 1621136330, 1889728930 }, + { 1553642730, 784545882, 1621136330 } }, + // Matrix for nskip = 3 * 8 ^ 54: + { { 900526416, 798626824, 3879214027 }, + { 2219774094, 2513781045, 798626824 }, + { 1455564465, 3987302058, 2513781045 } }, + // Matrix for nskip = 4 * 8 ^ 54: + { { 1715219514, 2831829177, 929124824 }, + { 997274536, 404228189, 2831829177 }, + { 1386575385, 4107238699, 404228189 } }, + // Matrix for nskip = 5 * 8 ^ 54: + { { 3216180354, 346253769, 2204236686 }, + { 620690291, 2037367915, 346253769 }, + { 1423172488, 2780020913, 2037367915 } }, + // Matrix for nskip = 6 * 8 ^ 54: + { { 1361559514, 2840866920, 2161766692 }, + { 3777816531, 4291736115, 2840866920 }, + { 1449118903, 455358549, 4291736115 } }, + // Matrix for nskip = 7 * 8 ^ 54: + { { 3361155093, 1442101330, 2915072798 }, + { 270047328, 973080601, 1442101330 }, + { 2538519465, 2830816977, 973080601 } }, + // Matrix for nskip = 1 * 8 ^ 55: + { { 3928131551, 2912523524, 1840499723 }, + { 4216003022, 2970489088, 2912523524 }, + { 1158689953, 1425511081, 2970489088 } }, + // Matrix for nskip = 2 * 8 ^ 55: + { { 2807004452, 2510299562, 271603006 }, + { 2505735035, 2370490899, 2510299562 }, + { 10873814, 2450376936, 2370490899 } }, + // Matrix for nskip = 3 * 8 ^ 55: + { { 895842640, 1513759891, 652184790 }, + { 337719276, 3793171443, 1513759891 }, + { 661495819, 1882293939, 3793171443 } }, + // Matrix for nskip = 4 * 8 ^ 55: + { { 2000734342, 1113679064, 2502160539 }, + { 1475266926, 2787925323, 1113679064 }, + { 1475797635, 3044470744, 2787925323 } }, + // Matrix for nskip = 5 * 8 ^ 55: + { { 1766616799, 722317846, 1586650055 }, + { 1016766460, 76599155, 722317846 }, + { 2574759301, 623201703, 76599155 } }, + // Matrix for nskip = 6 * 8 ^ 55: + { { 3664739404, 4014926443, 1080154168 }, + { 2495955387, 1724853627, 4014926443 }, + { 536042925, 1256783759, 1724853627 } }, + // Matrix for nskip = 7 * 8 ^ 55: + { { 4046813655, 3373283605, 3767126799 }, + { 1560329332, 2618021767, 3373283605 }, + { 527165723, 2030169433, 2618021767 } }, + // Matrix for nskip = 1 * 8 ^ 56: + { { 1457157056, 1252556678, 3073232607 }, + { 1926798761, 3639907189, 1252556678 }, + { 2067740348, 2256217204, 3639907189 } }, + // Matrix for nskip = 2 * 8 ^ 56: + { { 3740999688, 1035400458, 3162437311 }, + { 4126312242, 686702830, 1035400458 }, + { 1699805291, 667792040, 686702830 } }, + // Matrix for nskip = 3 * 8 ^ 56: + { { 1345468819, 1338322079, 817781640 }, + { 2710885009, 1935673443, 1338322079 }, + { 877889863, 2304324596, 1935673443 } }, + // Matrix for nskip = 4 * 8 ^ 56: + { { 2422495016, 3203768688, 1858240466 }, + { 848719394, 4092709154, 3203768688 }, + { 659945473, 1863075174, 4092709154 } }, + // Matrix for nskip = 5 * 8 ^ 56: + { { 21345609, 2944772441, 1446242483 }, + { 3854092115, 3931174287, 2944772441 }, + { 3818334033, 340393141, 3931174287 } }, + // Matrix for nskip = 6 * 8 ^ 56: + { { 2472609977, 1572317229, 2146084483 }, + { 386210076, 1579232146, 1572317229 }, + { 3154153453, 3349077947, 1579232146 } }, + // Matrix for nskip = 7 * 8 ^ 56: + { { 3934658083, 1547798902, 578076866 }, + { 3707114992, 1649964845, 1547798902 }, + { 3740686873, 217906160, 1649964845 } }, + // Matrix for nskip = 1 * 8 ^ 57: + { { 246817944, 871751352, 2834051003 }, + { 3976202597, 3721214025, 871751352 }, + { 783929942, 745295675, 3721214025 } }, + // Matrix for nskip = 2 * 8 ^ 57: + { { 3811740424, 3603608092, 2365398362 }, + { 3826150877, 2906557036, 3603608092 }, + { 2300510686, 966815948, 2906557036 } }, + // Matrix for nskip = 3 * 8 ^ 57: + { { 2004086842, 752045049, 1443259442 }, + { 4222485982, 2275171478, 752045049 }, + { 959250674, 2731257760, 2275171478 } }, + // Matrix for nskip = 4 * 8 ^ 57: + { { 2816329160, 18201123, 3367710570 }, + { 437309679, 2220769388, 18201123 }, + { 1346863388, 705296543, 2220769388 } }, + // Matrix for nskip = 5 * 8 ^ 57: + { { 3868848671, 3006483395, 3903615747 }, + { 1680524656, 2885742075, 3006483395 }, + { 796648897, 2121364560, 2885742075 } }, + // Matrix for nskip = 6 * 8 ^ 57: + { { 2743985808, 1183199523, 686976485 }, + { 3080242732, 497836434, 1183199523 }, + { 2146196184, 523073130, 497836434 } }, + // Matrix for nskip = 7 * 8 ^ 57: + { { 281969912, 3168583843, 3387530534 }, + { 3604375441, 89658761, 3168583843 }, + { 3122537866, 3405552447, 89658761 } }, + // Matrix for nskip = 1 * 8 ^ 58: + { { 3310028953, 1662315499, 132645114 }, + { 2572908401, 3105849797, 1662315499 }, + { 1937586849, 1735620028, 3105849797 } }, + // Matrix for nskip = 2 * 8 ^ 58: + { { 461386353, 1359675853, 3599822966 }, + { 106675209, 2044154050, 1359675853 }, + { 1787730088, 1149892630, 2044154050 } }, + // Matrix for nskip = 3 * 8 ^ 58: + { { 1678397435, 2034254929, 404593054 }, + { 308885052, 4143854702, 2034254929 }, + { 1276625905, 1557265403, 4143854702 } }, + // Matrix for nskip = 4 * 8 ^ 58: + { { 3303902397, 345146034, 1417149696 }, + { 2231869247, 1116882637, 345146034 }, + { 1846832385, 79626976, 1116882637 } }, + // Matrix for nskip = 5 * 8 ^ 58: + { { 3163825854, 3437355918, 3790302358 }, + { 2966738005, 405418248, 3437355918 }, + { 2935909124, 1737823953, 405418248 } }, + // Matrix for nskip = 6 * 8 ^ 58: + { { 4188280456, 4245794318, 2115856958 }, + { 3899866941, 2230248511, 4245794318 }, + { 4151131385, 1810874924, 2230248511 } }, + // Matrix for nskip = 7 * 8 ^ 58: + { { 3183442289, 2647800101, 3155584995 }, + { 1803347712, 3081729031, 2647800101 }, + { 344634507, 408464888, 3081729031 } }, + // Matrix for nskip = 1 * 8 ^ 59: + { { 2765049417, 3117782790, 1805260159 }, + { 3796182890, 1101141726, 3117782790 }, + { 224270120, 1004001443, 1101141726 } }, + // Matrix for nskip = 2 * 8 ^ 59: + { { 89118668, 2494198515, 1356989069 }, + { 2490435731, 997151755, 2494198515 }, + { 1175528637, 3444341166, 997151755 } }, + // Matrix for nskip = 3 * 8 ^ 59: + { { 2610383359, 3160454394, 1595264559 }, + { 613651010, 1733540130, 3160454394 }, + { 1119988193, 1810350755, 1733540130 } }, + // Matrix for nskip = 4 * 8 ^ 59: + { { 2340639019, 510225634, 286119182 }, + { 2045217287, 1194574818, 510225634 }, + { 2662281592, 1728500627, 1194574818 } }, + // Matrix for nskip = 5 * 8 ^ 59: + { { 1447842232, 184782823, 1797257364 }, + { 2190899193, 2854828033, 184782823 }, + { 4138436503, 3783089951, 2854828033 } }, + // Matrix for nskip = 6 * 8 ^ 59: + { { 3892495210, 2262141136, 1078367555 }, + { 3549231332, 2559113701, 2262141136 }, + { 4146978688, 2236162592, 2559113701 } }, + // Matrix for nskip = 7 * 8 ^ 59: + { { 1510077366, 825286037, 2959985729 }, + { 830287146, 781759955, 825286037 }, + { 359509185, 3182735706, 781759955 } }, + // Matrix for nskip = 1 * 8 ^ 60: + { { 210787847, 1189120688, 2848040407 }, + { 1087786165, 2343328484, 1189120688 }, + { 3465141330, 2893041005, 2343328484 } }, + // Matrix for nskip = 2 * 8 ^ 60: + { { 3438170226, 3236285682, 962036916 }, + { 2873263091, 215280489, 3236285682 }, + { 730413847, 1474823842, 215280489 } }, + // Matrix for nskip = 3 * 8 ^ 60: + { { 1877599976, 489218847, 1841260926 }, + { 1267710679, 4177426677, 489218847 }, + { 3908192573, 1193948814, 4177426677 } }, + // Matrix for nskip = 4 * 8 ^ 60: + { { 1566461658, 133010024, 2886695328 }, + { 2835827516, 653809404, 133010024 }, + { 3082882924, 3710942807, 653809404 } }, + // Matrix for nskip = 5 * 8 ^ 60: + { { 1018639212, 4003411060, 3748771156 }, + { 933110981, 1484000297, 4003411060 }, + { 3415991698, 3188783681, 1484000297 } }, + // Matrix for nskip = 6 * 8 ^ 60: + { { 2630823869, 3185784250, 1624263326 }, + { 1151112872, 440283001, 3185784250 }, + { 4029103059, 1089550911, 440283001 } }, + // Matrix for nskip = 7 * 8 ^ 60: + { { 2558003006, 4161490031, 868072046 }, + { 2993166332, 1972186265, 4161490031 }, + { 1890899803, 3731240792, 1972186265 } }, + // Matrix for nskip = 1 * 8 ^ 61: + { { 4201558916, 1263786956, 326001602 }, + { 762846463, 621546357, 1263786956 }, + { 2697142404, 1156650856, 621546357 } }, + // Matrix for nskip = 2 * 8 ^ 61: + { { 2655768102, 2339029465, 2430211448 }, + { 2669906627, 403962847, 2339029465 }, + { 1483118807, 639660658, 403962847 } }, + // Matrix for nskip = 3 * 8 ^ 61: + { { 343789192, 2523152864, 3692813188 }, + { 4182218791, 1387544806, 2523152864 }, + { 3364170107, 1607749365, 1387544806 } }, + // Matrix for nskip = 4 * 8 ^ 61: + { { 3508595200, 4228486662, 754946994 }, + { 1913148390, 3500531602, 4228486662 }, + { 24637, 3773159052, 3500531602 } }, + // Matrix for nskip = 5 * 8 ^ 61: + { { 1767736432, 2782451483, 925961005 }, + { 1898573829, 779641045, 2782451483 }, + { 4172425777, 3053709304, 779641045 } }, + // Matrix for nskip = 6 * 8 ^ 61: + { { 917982480, 676540794, 3402535509 }, + { 1997794025, 3184854268, 676540794 }, + { 2501974390, 2557204628, 3184854268 } }, + // Matrix for nskip = 7 * 8 ^ 61: + { { 2265059434, 3533015776, 2085907395 }, + { 1105268907, 2837239505, 3533015776 }, + { 3031242459, 1173739788, 2837239505 } }, + // Matrix for nskip = 1 * 8 ^ 62: + { { 4024866227, 1143874914, 3205058469 }, + { 2970344133, 2873927273, 1143874914 }, + { 2167114735, 4095476435, 2873927273 } }, + // Matrix for nskip = 2 * 8 ^ 62: + { { 1479401095, 2958366486, 3027708794 }, + { 2704486034, 3574053987, 2958366486 }, + { 3630964515, 1276667706, 3574053987 } }, + // Matrix for nskip = 3 * 8 ^ 62: + { { 3471121, 4212261536, 2870367456 }, + { 3210276198, 3855580426, 4212261536 }, + { 2974755971, 3723431054, 3855580426 } }, + // Matrix for nskip = 4 * 8 ^ 62: + { { 2035927380, 1363628533, 818363998 }, + { 3023327955, 3968427114, 1363628533 }, + { 1284825950, 2871663372, 3968427114 } }, + // Matrix for nskip = 5 * 8 ^ 62: + { { 4289867114, 1817891047, 2823353497 }, + { 910331225, 3868760780, 1817891047 }, + { 2783151834, 2379034525, 3868760780 } }, + // Matrix for nskip = 6 * 8 ^ 62: + { { 2979837612, 1089982006, 1663630835 }, + { 709699817, 1486004709, 1089982006 }, + { 1956455708, 1787357723, 1486004709 } }, + // Matrix for nskip = 7 * 8 ^ 62: + { { 2852981955, 2215534550, 574323950 }, + { 1169533157, 2975065186, 2215534550 }, + { 2290801870, 428188634, 2975065186 } }, + // Matrix for nskip = 1 * 8 ^ 63: + { { 3827747418, 3897287251, 4106993377 }, + { 1527779946, 3221052941, 3897287251 }, + { 4178727866, 4281160673, 3221052941 } }, + // Matrix for nskip = 2 * 8 ^ 63: + { { 1174358892, 2835476193, 959978619 }, + { 850076464, 3774782533, 2835476193 }, + { 3880910680, 3237990203, 3774782533 } }, + // Matrix for nskip = 3 * 8 ^ 63: + { { 1400690756, 823435890, 1896847210 }, + { 3000499818, 1124911735, 823435890 }, + { 1381972838, 2683742666, 1124911735 } }, + // Matrix for nskip = 4 * 8 ^ 63: + { { 3128011728, 1998893251, 1400155768 }, + { 1430713735, 2850730926, 1998893251 }, + { 1073801764, 2374744218, 2850730926 } }, + // Matrix for nskip = 5 * 8 ^ 63: + { { 1152423219, 3000721466, 850809698 }, + { 764299143, 3684505492, 3000721466 }, + { 3524599640, 1299858048, 3684505492 } }, + // Matrix for nskip = 6 * 8 ^ 63: + { { 2188428625, 3090564778, 4205068615 }, + { 1911908313, 34180751, 3090564778 }, + { 3382776937, 194682771, 34180751 } }, + // Matrix for nskip = 7 * 8 ^ 63: + { { 3050396426, 627769205, 3010308075 }, + { 987718671, 3026731980, 627769205 }, + { 3527778260, 4200640347, 3026731980 } }, + // Matrix for nskip = 1 * 8 ^ 64: + { { 364496809, 3951443831, 2338985995 }, + { 2365728271, 1745134545, 3951443831 }, + { 1076500940, 1589192585, 1745134545 } }, + // Matrix for nskip = 2 * 8 ^ 64: + { { 3304498837, 1325046906, 3381970501 }, + { 1563368115, 3116266625, 1325046906 }, + { 244825785, 4251678855, 3116266625 } }, + // Matrix for nskip = 3 * 8 ^ 64: + { { 4133678667, 2048440215, 4035662430 }, + { 4086919994, 519191900, 2048440215 }, + { 2789936683, 4051608893, 519191900 } }, + // Matrix for nskip = 4 * 8 ^ 64: + { { 3603289991, 1324164821, 1776019579 }, + { 1734804890, 3151589272, 1324164821 }, + { 2411297223, 3296772386, 3151589272 } }, + // Matrix for nskip = 5 * 8 ^ 64: + { { 2599419541, 2726072264, 662164094 }, + { 1554537872, 2065618870, 2726072264 }, + { 1049180268, 439080215, 2065618870 } }, + // Matrix for nskip = 6 * 8 ^ 64: + { { 1694996757, 2289284793, 2258832764 }, + { 1982364129, 3971544391, 2289284793 }, + { 1140613093, 2605325759, 3971544391 } }, + // Matrix for nskip = 7 * 8 ^ 64: + { { 3344108032, 1353133572, 2611828466 }, + { 729814057, 219879593, 1353133572 }, + { 1513768211, 1797897504, 219879593 } } }, + // Matrix for nskip = 1 * 8 ^ 0: + { { { 0, 1, 0 }, { 0, 0, 1 }, { 4293573854, 0, 527612 } }, + // Matrix for nskip = 2 * 8 ^ 0: + { { 0, 0, 1 }, { 4293573854, 0, 527612 }, { 2706407399, 4293573854, 3497978192 } }, + // Matrix for nskip = 3 * 8 ^ 0: + { { 4293573854, 0, 527612 }, + { 2706407399, 4293573854, 3497978192 }, + { 1431525864, 2706407399, 3281754271 } }, + // Matrix for nskip = 4 * 8 ^ 0: + { { 2706407399, 4293573854, 3497978192 }, + { 1431525864, 2706407399, 3281754271 }, + { 97673890, 1431525864, 1673476130 } }, + // Matrix for nskip = 5 * 8 ^ 0: + { { 1431525864, 2706407399, 3281754271 }, + { 97673890, 1431525864, 1673476130 }, + { 2680076935, 97673890, 1430724370 } }, + // Matrix for nskip = 6 * 8 ^ 0: + { { 97673890, 1431525864, 1673476130 }, + { 2680076935, 97673890, 1430724370 }, + { 3405842137, 2680076935, 893509979 } }, + // Matrix for nskip = 7 * 8 ^ 0: + { { 2680076935, 97673890, 1430724370 }, + { 3405842137, 2680076935, 893509979 }, + { 4035147174, 3405842137, 3280220074 } }, + // Matrix for nskip = 1 * 8 ^ 1: + { { 3405842137, 2680076935, 893509979 }, + { 4035147174, 3405842137, 3280220074 }, + { 2623373296, 4035147174, 361718588 } }, + // Matrix for nskip = 2 * 8 ^ 1: + { { 818368950, 3790774567, 3542344109 }, + { 1817134745, 818368950, 3321940838 }, + { 3493477402, 1817134745, 2854655037 } }, + // Matrix for nskip = 3 * 8 ^ 1: + { { 508190223, 940389731, 295549677 }, + { 548891792, 508190223, 4243623497 }, + { 1618914183, 548891792, 2585942386 } }, + // Matrix for nskip = 4 * 8 ^ 1: + { { 498682467, 2928649385, 811441367 }, + { 1777037472, 498682467, 479207863 }, + { 3058260025, 1777037472, 1528225099 } }, + // Matrix for nskip = 5 * 8 ^ 1: + { { 1605006689, 1112484358, 2137070446 }, + { 3785946674, 1605006689, 1949907406 }, + { 3243030173, 3785946674, 2339202713 } }, + // Matrix for nskip = 6 * 8 ^ 1: + { { 1603012465, 493710616, 1996495269 }, + { 3369502947, 1603012465, 1576432507 }, + { 3762770058, 3369502947, 254897698 } }, + // Matrix for nskip = 7 * 8 ^ 1: + { { 1138020476, 4025114134, 3077305804 }, + { 4152260747, 1138020476, 1057298006 }, + { 1828211552, 4152260747, 3984471979 } }, + // Matrix for nskip = 1 * 8 ^ 2: + { { 3893311647, 3140922085, 64039185 }, + { 82107183, 3893311647, 2655465224 }, + { 1674879036, 82107183, 1089381262 } }, + // Matrix for nskip = 2 * 8 ^ 2: + { { 28639152, 3496041927, 2231910770 }, + { 3174683233, 28639152, 2828785870 }, + { 3681140872, 3174683233, 3910194649 } }, + // Matrix for nskip = 3 * 8 ^ 2: + { { 3488684910, 1250231333, 763303055 }, + { 681409874, 3488684910, 751154769 }, + { 3783909260, 681409874, 1465244270 } }, + // Matrix for nskip = 4 * 8 ^ 2: + { { 1463826069, 300842059, 3313769518 }, + { 1799677538, 1463826069, 3174861078 }, + { 1882279394, 1799677538, 3509975160 } }, + // Matrix for nskip = 5 * 8 ^ 2: + { { 2793448161, 3690337147, 4181759810 }, + { 514622120, 2793448161, 3027286223 }, + { 241620347, 514622120, 1328063696 } }, + // Matrix for nskip = 6 * 8 ^ 2: + { { 3250099852, 3207068910, 3709263791 }, + { 2342747328, 3250099852, 3729690850 }, + { 3983203494, 2342747328, 1023622970 } }, + // Matrix for nskip = 7 * 8 ^ 2: + { { 3136295372, 3178055245, 2818424094 }, + { 2036073935, 3136295372, 3231583326 }, + { 1782478065, 2036073935, 1053332972 } }, + // Matrix for nskip = 1 * 8 ^ 3: + { { 2092194020, 184076987, 2202401252 }, + { 3103629604, 2092194020, 3409560232 }, + { 4257445059, 3103629604, 2390202783 } }, + // Matrix for nskip = 2 * 8 ^ 3: + { { 812917091, 2574011276, 4168802395 }, + { 209817750, 812917091, 2974870628 }, + { 3238802184, 209817750, 3692836406 } }, + // Matrix for nskip = 3 * 8 ^ 3: + { { 1621943577, 2244624888, 38864005 }, + { 3618177584, 1621943577, 3295260066 }, + { 414159965, 3618177584, 1095692911 } }, + // Matrix for nskip = 4 * 8 ^ 3: + { { 477309738, 3314523413, 3442242150 }, + { 2755731404, 477309738, 2782713347 }, + { 1606221490, 2755731404, 1033463096 } }, + // Matrix for nskip = 5 * 8 ^ 3: + { { 3233499061, 2494617440, 1002517819 }, + { 3026123612, 3233499061, 3338202446 }, + { 1979145017, 3026123612, 3790308130 } }, + // Matrix for nskip = 6 * 8 ^ 3: + { { 2567113113, 781663248, 3993869449 }, + { 402756912, 2567113113, 2817097718 }, + { 3190930010, 402756912, 2884691291 } }, + // Matrix for nskip = 7 * 8 ^ 3: + { { 2223683788, 4195752245, 2738363134 }, + { 1171605168, 2223683788, 3904649711 }, + { 2631005941, 1171605168, 3445807882 } }, + // Matrix for nskip = 1 * 8 ^ 4: + { { 2155469603, 3326516116, 3843369786 }, + { 288604458, 2155469603, 571673571 }, + { 1501677614, 288604458, 2928213494 } }, + // Matrix for nskip = 2 * 8 ^ 4: + { { 2082469029, 749754403, 3963963316 }, + { 2764859700, 2082469029, 3576428059 }, + { 2840894706, 2764859700, 1782279859 } }, + // Matrix for nskip = 3 * 8 ^ 4: + { { 1583407457, 2056027805, 55614242 }, + { 2405645826, 1583407457, 1737043333 }, + { 1118910623, 2405645826, 1180559812 } }, + // Matrix for nskip = 4 * 8 ^ 4: + { { 3760163766, 1041986082, 1799196192 }, + { 1022129134, 3760163766, 1332558840 }, + { 276873446, 1022129134, 3979423632 } }, + // Matrix for nskip = 5 * 8 ^ 4: + { { 1438626566, 3619082489, 1569836243 }, + { 3671597039, 1438626566, 907924984 }, + { 3732297029, 3671597039, 1221779212 } }, + // Matrix for nskip = 6 * 8 ^ 4: + { { 483787924, 3115606677, 2374703971 }, + { 117552025, 483787924, 4234241969 }, + { 774331833, 117552025, 530787287 } }, + // Matrix for nskip = 7 * 8 ^ 4: + { { 955925224, 1961750426, 3644821859 }, + { 213414981, 955925224, 927956770 }, + { 1671634731, 213414981, 4186423122 } }, + // Matrix for nskip = 1 * 8 ^ 5: + { { 1021313167, 1312544548, 1716381787 }, + { 3037868518, 1021313167, 199085085 }, + { 2582787611, 3037868518, 3539882179 } }, + // Matrix for nskip = 2 * 8 ^ 5: + { { 2569413030, 1631336015, 2594942403 }, + { 1030618503, 2569413030, 3467650326 }, + { 1998739584, 1030618503, 3174552073 } }, + // Matrix for nskip = 3 * 8 ^ 5: + { { 2179955734, 1825159949, 1082151624 }, + { 937147983, 2179955734, 978382746 }, + { 2629591623, 937147983, 3579678559 } }, + // Matrix for nskip = 4 * 8 ^ 5: + { { 2334639309, 3114094203, 601680947 }, + { 2110199318, 2334639309, 678342865 }, + { 1649523168, 2110199318, 2154948056 } }, + // Matrix for nskip = 5 * 8 ^ 5: + { { 2715012491, 247412130, 1566452082 }, + { 3425439428, 2715012491, 3004133824 }, + { 1615468474, 3425439428, 588082730 } }, + // Matrix for nskip = 6 * 8 ^ 5: + { { 2654502125, 654123598, 3954383978 }, + { 2454987531, 2654502125, 161781366 }, + { 3631058630, 2454987531, 2718719935 } }, + // Matrix for nskip = 7 * 8 ^ 5: + { { 2620087047, 1022484731, 3275546712 }, + { 4119759001, 2620087047, 1849544363 }, + { 1245152096, 4119759001, 2978477502 } }, + // Matrix for nskip = 1 * 8 ^ 6: + { { 563657176, 191330473, 1641595774 }, + { 780563537, 563657176, 3029522338 }, + { 2037330914, 780563537, 2084602709 } }, + // Matrix for nskip = 2 * 8 ^ 6: + { { 3414769923, 1968799026, 2238126504 }, + { 832866376, 3414769923, 3754780168 }, + { 2165145850, 832866376, 1594768331 } }, + // Matrix for nskip = 3 * 8 ^ 6: + { { 1457310151, 2262086849, 2480319255 }, + { 1778576621, 1457310151, 367796024 }, + { 444536774, 1778576621, 873301213 } }, + // Matrix for nskip = 4 * 8 ^ 6: + { { 1646861218, 2317984620, 2301581548 }, + { 2672536210, 1646861218, 359763062 }, + { 2391283983, 2672536210, 1885870777 } }, + // Matrix for nskip = 5 * 8 ^ 6: + { { 2962497351, 1089931025, 970191811 }, + { 2050228336, 2962497351, 1568166288 }, + { 3288162415, 2050228336, 3921597644 } }, + // Matrix for nskip = 6 * 8 ^ 6: + { { 2468196470, 3544275509, 3557597196 }, + { 3893425026, 2468196470, 2061293842 }, + { 2019325804, 3893425026, 2905314 } }, + // Matrix for nskip = 7 * 8 ^ 6: + { { 3407411651, 4206194937, 989129012 }, + { 1280115996, 3407411651, 1843205351 }, + { 752661975, 1280115996, 693779416 } }, + // Matrix for nskip = 1 * 8 ^ 7: + { { 841254072, 3765813448, 1635365181 }, + { 2013240130, 841254072, 605925849 }, + { 3743932305, 2013240130, 400681955 } }, + // Matrix for nskip = 2 * 8 ^ 7: + { { 1930213004, 2072952279, 3077694794 }, + { 3579956569, 1930213004, 2478539210 }, + { 1960229502, 3579956569, 1455652656 } }, + // Matrix for nskip = 3 * 8 ^ 7: + { { 490241598, 1155806426, 2341304300 }, + { 1821354750, 490241598, 2364275695 }, + { 3717764728, 1821354750, 1349151461 } }, + // Matrix for nskip = 4 * 8 ^ 7: + { { 1097613522, 1784540933, 1194440107 }, + { 321747515, 1097613522, 1225209584 }, + { 74521379, 321747515, 4288531000 } }, + // Matrix for nskip = 5 * 8 ^ 7: + { { 3795899570, 3294470896, 2568537852 }, + { 1615892324, 3795899570, 2277651644 }, + { 245018475, 1615892324, 3269831184 } }, + // Matrix for nskip = 6 * 8 ^ 7: + { { 2284610128, 1711688841, 2988405862 }, + { 1861018675, 2284610128, 3450880655 }, + { 4077631310, 1861018675, 2595646099 } }, + // Matrix for nskip = 7 * 8 ^ 7: + { { 1338063869, 4236188627, 4005334159 }, + { 2199059659, 1338063869, 3613475430 }, + { 954928333, 2199059659, 1383222658 } }, + // Matrix for nskip = 1 * 8 ^ 8: + { { 143812745, 3254530816, 3514348856 }, + { 769295000, 143812745, 2468210728 }, + { 1927161272, 769295000, 522705580 } }, + // Matrix for nskip = 2 * 8 ^ 8: + { { 2692035063, 2596905012, 1643240704 }, + { 1103432342, 2692035063, 1446182108 }, + { 4161111774, 1103432342, 3076435551 } }, + // Matrix for nskip = 3 * 8 ^ 8: + { { 1809137988, 2412502608, 3993875038 }, + { 1332423877, 1809137988, 3101816103 }, + { 1366553339, 1332423877, 2986424418 } }, + // Matrix for nskip = 4 * 8 ^ 8: + { { 2375319030, 1391532370, 3742334018 }, + { 1202100604, 2375319030, 4098434768 }, + { 2327872488, 1202100604, 1471526950 } }, + // Matrix for nskip = 5 * 8 ^ 8: + { { 953526753, 3517620599, 1558514368 }, + { 3674658855, 953526753, 1517070807 }, + { 828283166, 3674658855, 2689974385 } }, + // Matrix for nskip = 6 * 8 ^ 8: + { { 3063334100, 3228801559, 269715831 }, + { 612058994, 3063334100, 4143597212 }, + { 1918225488, 612058994, 2055175984 } }, + // Matrix for nskip = 7 * 8 ^ 8: + { { 2623568215, 482061697, 191091208 }, + { 2499397071, 2623568215, 2970642011 }, + { 759749547, 2499397071, 3510580843 } }, + // Matrix for nskip = 1 * 8 ^ 9: + { { 4269164791, 2795313144, 2507855960 }, + { 4245372460, 4269164791, 4094914553 }, + { 3873219634, 4245372460, 1473695507 } }, + // Matrix for nskip = 2 * 8 ^ 9: + { { 513890845, 1208902926, 2870530442 }, + { 1984873167, 513890845, 1257532340 }, + { 1212627640, 1984873167, 2354363842 } }, + // Matrix for nskip = 3 * 8 ^ 9: + { { 3386048256, 4196280201, 3121820178 }, + { 2926727276, 3386048256, 2790144637 }, + { 3970110476, 2926727276, 3495704635 } }, + // Matrix for nskip = 4 * 8 ^ 9: + { { 1848364568, 1552116673, 3496528455 }, + { 4160778291, 1848364568, 141769900 }, + { 3611019106, 4160778291, 596424080 } }, + // Matrix for nskip = 5 * 8 ^ 9: + { { 4194097650, 3986230829, 3091752508 }, + { 3352554321, 4194097650, 4041363667 }, + { 3822925061, 3352554321, 3748054631 } }, + // Matrix for nskip = 6 * 8 ^ 9: + { { 1292986218, 172755364, 997232463 }, + { 1505642955, 1292986218, 4112978448 }, + { 1757204931, 1505642955, 3038511100 } }, + // Matrix for nskip = 7 * 8 ^ 9: + { { 3805104355, 3540279669, 2118304338 }, + { 1984875159, 3805104355, 3000869736 }, + { 6466700, 1984875159, 1778898381 } }, + // Matrix for nskip = 1 * 8 ^ 10: + { { 364070020, 3520039729, 837362349 }, + { 2544671570, 364070020, 2188646679 }, + { 163978331, 2544671570, 672947816 } }, + // Matrix for nskip = 2 * 8 ^ 10: + { { 1192700714, 3968150021, 298357363 }, + { 635565666, 1192700714, 2589432341 }, + { 2548654227, 635565666, 3531570992 } }, + // Matrix for nskip = 3 * 8 ^ 10: + { { 3438963520, 1845346034, 2575726025 }, + { 2187600669, 3438963520, 958916489 }, + { 2672427080, 2187600669, 3420061274 } }, + // Matrix for nskip = 4 * 8 ^ 10: + { { 2709640529, 676525399, 875361870 }, + { 1315499519, 2709640529, 3842690720 }, + { 3300994644, 1315499519, 2446760804 } }, + // Matrix for nskip = 5 * 8 ^ 10: + { { 1292317767, 393678487, 143711415 }, + { 1162526988, 1292317767, 1311572745 }, + { 344898630, 1162526988, 1362796547 } }, + // Matrix for nskip = 6 * 8 ^ 10: + { { 2857812374, 598000082, 2114605560 }, + { 3454872661, 2857812374, 2738653578 }, + { 2522086851, 3454872661, 1190449620 } }, + // Matrix for nskip = 7 * 8 ^ 10: + { { 2614530149, 753841941, 146778273 }, + { 2511297323, 2614530149, 588764284 }, + { 1785429779, 2511297323, 1269211096 } }, + // Matrix for nskip = 1 * 8 ^ 11: + { { 2742149264, 1410604392, 3032350755 }, + { 3774935330, 2742149264, 597633965 }, + { 4085935803, 3774935330, 3952463556 } }, + // Matrix for nskip = 2 * 8 ^ 11: + { { 3878579563, 845297523, 1721916511 }, + { 2077922420, 3878579563, 3651360351 }, + { 2177255734, 2077922420, 3791239282 } }, + // Matrix for nskip = 3 * 8 ^ 11: + { { 2642777370, 1064863813, 4046131253 }, + { 2032494710, 2642777370, 3511906271 }, + { 2787706468, 2032494710, 1602633162 } }, + // Matrix for nskip = 4 * 8 ^ 11: + { { 1570315355, 4252790045, 3522351060 }, + { 2324624266, 1570315355, 3594939336 }, + { 1725087354, 2324624266, 1338343327 } }, + // Matrix for nskip = 5 * 8 ^ 11: + { { 3128806513, 3431512800, 3791370211 }, + { 26016991, 3128806513, 1182007239 }, + { 2629261386, 26016991, 1219288409 } }, + // Matrix for nskip = 6 * 8 ^ 11: + { { 2323129699, 2040722667, 4032945011 }, + { 1824515104, 2323129699, 783304238 }, + { 1910382756, 1824515104, 2009721680 } }, + // Matrix for nskip = 7 * 8 ^ 11: + { { 495056704, 1303223717, 299029371 }, + { 3001848199, 495056704, 2298546607 }, + { 528121192, 3001848199, 3574765936 } }, + // Matrix for nskip = 1 * 8 ^ 12: + { { 2305761589, 381933244, 3663579047 }, + { 1355307047, 2305761589, 313617972 }, + { 992174375, 1355307047, 3881593435 } }, + // Matrix for nskip = 2 * 8 ^ 12: + { { 1667857811, 1564715297, 2263851601 }, + { 3791771273, 1667857811, 4196134923 }, + { 3347975047, 3791771273, 615040705 } }, + // Matrix for nskip = 3 * 8 ^ 12: + { { 2699274746, 2208033721, 3314336764 }, + { 1723493827, 2699274746, 3721738282 }, + { 3116429712, 1723493827, 763211059 } }, + // Matrix for nskip = 4 * 8 ^ 12: + { { 4093947334, 3454015638, 2815567716 }, + { 4261953004, 4093947334, 3973733876 }, + { 2979573134, 4261953004, 3757047667 } }, + // Matrix for nskip = 5 * 8 ^ 12: + { { 1497333242, 3837209858, 4043986454 }, + { 3928412309, 1497333242, 4232950837 }, + { 868538065, 3928412309, 3223762258 } }, + // Matrix for nskip = 6 * 8 ^ 12: + { { 4178728130, 2981026540, 3927272953 }, + { 668310420, 4178728130, 551557198 }, + { 3532851694, 668310420, 4119399398 } }, + // Matrix for nskip = 7 * 8 ^ 12: + { { 4121879899, 2179415297, 3607008098 }, + { 243696529, 4121879899, 168490644 }, + { 3444486351, 243696529, 752516370 } }, + // Matrix for nskip = 1 * 8 ^ 13: + { { 250120061, 570149551, 1513430926 }, + { 3178644752, 250120061, 1701869032 }, + { 4172515680, 3178644752, 4213855850 } }, + // Matrix for nskip = 2 * 8 ^ 13: + { { 4158106802, 3062358456, 1815738463 }, + { 1379176112, 4158106802, 3926509890 }, + { 2842564878, 1379176112, 2852219546 } }, + // Matrix for nskip = 3 * 8 ^ 13: + { { 4056930326, 2130453857, 3298513997 }, + { 3059400883, 4056930326, 439468763 }, + { 546163799, 3059400883, 1884270041 } }, + // Matrix for nskip = 4 * 8 ^ 13: + { { 931848746, 256263523, 2633569246 }, + { 3284646837, 931848746, 2567084715 }, + { 415258465, 3284646837, 2017565947 } }, + // Matrix for nskip = 5 * 8 ^ 13: + { { 239941751, 4065438988, 4260302551 }, + { 3480241466, 239941751, 1576122049 }, + { 4073589963, 3480241466, 2593293965 } }, + // Matrix for nskip = 6 * 8 ^ 13: + { { 507915211, 625612469, 3733827320 }, + { 3909587424, 507915211, 3313512626 }, + { 1707582600, 3909587424, 985910059 } }, + // Matrix for nskip = 7 * 8 ^ 13: + { { 3287778427, 3984689764, 3572719740 }, + { 207904085, 3287778427, 1330617931 }, + { 1894788630, 207904085, 1656936419 } }, + // Matrix for nskip = 1 * 8 ^ 14: + { { 1648005210, 1032291296, 3987397422 }, + { 1831496020, 1648005210, 2829448427 }, + { 1821082272, 1831496020, 2917140265 } }, + // Matrix for nskip = 2 * 8 ^ 14: + { { 4161327077, 489964129, 3870847744 }, + { 1669447863, 4161327077, 4292947198 }, + { 1522417114, 1669447863, 2652286672 } }, + // Matrix for nskip = 3 * 8 ^ 14: + { { 655280634, 3675619486, 3487203083 }, + { 3658400031, 655280634, 4093432727 }, + { 3338913609, 3658400031, 2005464907 } }, + // Matrix for nskip = 4 * 8 ^ 14: + { { 1270934555, 3136631324, 505612043 }, + { 2981474723, 1270934555, 2528619024 }, + { 625182639, 2981474723, 1008985039 } }, + // Matrix for nskip = 5 * 8 ^ 14: + { { 2670739471, 1317142118, 928068368 }, + { 3334643457, 2670739471, 3298861790 }, + { 3116973979, 3334643457, 4091848087 } }, + // Matrix for nskip = 6 * 8 ^ 14: + { { 87174298, 3714928458, 3674535785 }, + { 3591445536, 87174298, 3557842564 }, + { 2600409828, 3591445536, 3509905000 } }, + // Matrix for nskip = 7 * 8 ^ 14: + { { 1374849292, 3669747751, 313867341 }, + { 2805321474, 1374849292, 3672378692 }, + { 862662086, 2805321474, 1269888877 } }, + // Matrix for nskip = 1 * 8 ^ 15: + { { 280996820, 143706137, 3013099060 }, + { 1797675893, 280996820, 3743985508 }, + { 1123794455, 1797675893, 2460119169 } }, + // Matrix for nskip = 2 * 8 ^ 15: + { { 919218027, 4154920441, 1125672685 }, + { 3933041881, 919218027, 474242849 }, + { 564891116, 3933041881, 2263904321 } }, + // Matrix for nskip = 3 * 8 ^ 15: + { { 4046953169, 707039159, 59087677 }, + { 552285455, 4046953169, 3367709189 }, + { 1558638678, 552285455, 3541844079 } }, + // Matrix for nskip = 4 * 8 ^ 15: + { { 2920112852, 1965329198, 1177141043 }, + { 2135250851, 2920112852, 969184056 }, + { 296035385, 2135250851, 4267827987 } }, + // Matrix for nskip = 5 * 8 ^ 15: + { { 3182682829, 216191227, 2317042610 }, + { 3166912454, 3182682829, 3895260799 }, + { 3316963881, 3166912454, 2773111558 } }, + // Matrix for nskip = 6 * 8 ^ 15: + { { 4005961945, 962333604, 1596766252 }, + { 155090437, 4005961945, 3465811606 }, + { 995757623, 155090437, 842864023 } }, + // Matrix for nskip = 7 * 8 ^ 15: + { { 3616509225, 3195052585, 2901642782 }, + { 4257279454, 3616509225, 3209952933 }, + { 159699513, 4257279454, 746020360 } }, + // Matrix for nskip = 1 * 8 ^ 16: + { { 1481142942, 4120754772, 1088557292 }, + { 265491023, 1481142942, 2860005744 }, + { 301796252, 265491023, 1935975979 } }, + // Matrix for nskip = 2 * 8 ^ 16: + { { 2111859033, 2813610100, 1001476468 }, + { 73849832, 2111859033, 3980799998 }, + { 3330206241, 73849832, 1933943506 } }, + // Matrix for nskip = 3 * 8 ^ 16: + { { 4238802520, 1791251057, 3659825373 }, + { 756158319, 4238802520, 1208877520 }, + { 3666294602, 756158319, 1800377045 } }, + // Matrix for nskip = 4 * 8 ^ 16: + { { 1781286360, 3661231931, 3509383709 }, + { 2753158871, 1781286360, 3119883109 }, + { 3576525143, 2753158871, 551079002 } }, + // Matrix for nskip = 5 * 8 ^ 16: + { { 1150902763, 3730191199, 946744850 }, + { 3422735839, 1150902763, 2750435170 }, + { 3792794843, 3422735839, 808249292 } }, + // Matrix for nskip = 6 * 8 ^ 16: + { { 429107478, 1467997203, 689359610 }, + { 3244671951, 429107478, 2795337511 }, + { 3397069741, 3244671951, 186846111 } }, + // Matrix for nskip = 7 * 8 ^ 16: + { { 1453148331, 352897577, 3494583787 }, + { 2340848640, 1453148331, 3699044308 }, + { 3239904192, 2340848640, 209181640 } }, + // Matrix for nskip = 1 * 8 ^ 17: + { { 1185024844, 587779104, 1004942725 }, + { 3763632860, 1185024844, 947424568 }, + { 3811666068, 3763632860, 2352253462 } }, + // Matrix for nskip = 2 * 8 ^ 17: + { { 1310227170, 218138208, 3172947233 }, + { 766129426, 1310227170, 1808643264 }, + { 2226659371, 766129426, 3853798112 } }, + // Matrix for nskip = 3 * 8 ^ 17: + { { 3141996820, 528748361, 1701083808 }, + { 2360837423, 3141996820, 2513545590 }, + { 1425244435, 2360837423, 4192496132 } }, + // Matrix for nskip = 4 * 8 ^ 17: + { { 2230902378, 4243560874, 2491962392 }, + { 3836629116, 2230902378, 3637515403 }, + { 2846140932, 3836629116, 3083355464 } }, + // Matrix for nskip = 5 * 8 ^ 17: + { { 506476814, 1267508030, 152968246 }, + { 1117668151, 506476814, 2848688169 }, + { 3001214254, 1117668151, 3940649164 } }, + // Matrix for nskip = 6 * 8 ^ 17: + { { 1544421101, 772024440, 2364160468 }, + { 2733679040, 1544421101, 965008581 }, + { 2290142084, 2733679040, 3167919795 } }, + // Matrix for nskip = 7 * 8 ^ 17: + { { 2195717687, 3299928213, 1911548095 }, + { 3677807589, 2195717687, 2979544321 }, + { 1288751520, 3677807589, 1379093393 } }, + // Matrix for nskip = 1 * 8 ^ 18: + { { 999448569, 1464488480, 3344426626 }, + { 946166795, 999448569, 340856814 }, + { 3686999436, 946166795, 3231079441 } }, + // Matrix for nskip = 2 * 8 ^ 18: + { { 1226155368, 3477563770, 550006884 }, + { 2378667355, 1226155368, 1493409040 }, + { 260364836, 2378667355, 4133888397 } }, + // Matrix for nskip = 3 * 8 ^ 18: + { { 662024646, 2039234405, 3990280006 }, + { 2342461604, 662024646, 17023679 }, + { 1965981888, 2342461604, 1830518881 } }, + // Matrix for nskip = 4 * 8 ^ 18: + { { 1277901832, 310796286, 2818511068 }, + { 3088910653, 1277901832, 3303406025 }, + { 2507911914, 3088910653, 3712928074 } }, + // Matrix for nskip = 5 * 8 ^ 18: + { { 1103450261, 1722381279, 1394112836 }, + { 640743651, 1103450261, 198700731 }, + { 1095985628, 640743651, 2694625446 } }, + // Matrix for nskip = 6 * 8 ^ 18: + { { 4043182751, 1859059885, 1911031801 }, + { 2638851660, 4043182751, 4012210417 }, + { 783591639, 2638851660, 2188651115 } }, + // Matrix for nskip = 7 * 8 ^ 18: + { { 2318313639, 843870069, 2868175764 }, + { 3777361816, 2318313639, 4070019017 }, + { 2087410703, 3777361816, 2574355460 } }, + // Matrix for nskip = 1 * 8 ^ 19: + { { 481918378, 339570348, 1728801469 }, + { 1623163429, 481918378, 2209094694 }, + { 3146982514, 1623163429, 508445538 } }, + // Matrix for nskip = 2 * 8 ^ 19: + { { 3138921230, 2381863183, 1992357430 }, + { 1024510915, 3138921230, 2122851650 }, + { 1453455184, 1024510915, 941946604 } }, + // Matrix for nskip = 3 * 8 ^ 19: + { { 3235663883, 499846706, 3251827412 }, + { 801993191, 3235663883, 2207701640 }, + { 1201194185, 801993191, 2705683748 } }, + // Matrix for nskip = 4 * 8 ^ 19: + { { 2465372719, 1391015357, 3328905025 }, + { 1821933605, 2465372719, 1343489680 }, + { 3648970313, 1821933605, 1816599716 } }, + // Matrix for nskip = 5 * 8 ^ 19: + { { 582796091, 1306170361, 1574617829 }, + { 4167642903, 582796091, 284777447 }, + { 3124784671, 4167642903, 2539713186 } }, + // Matrix for nskip = 6 * 8 ^ 19: + { { 116486317, 2122591885, 1696181092 }, + { 381403852, 116486317, 2932149608 }, + { 3221291545, 381403852, 2742038256 } }, + // Matrix for nskip = 7 * 8 ^ 19: + { { 3035480468, 2182693760, 2351066479 }, + { 638141264, 3035480468, 100617977 }, + { 478641834, 638141264, 479301469 } }, + // Matrix for nskip = 1 * 8 ^ 20: + { { 118634664, 3358712512, 2492792220 }, + { 348833376, 118634664, 2495544591 }, + { 3235582254, 348833376, 4043157504 } }, + // Matrix for nskip = 2 * 8 ^ 20: + { { 2303067090, 3371139074, 1967771133 }, + { 598630070, 2303067090, 1819012637 }, + { 2049250561, 598630070, 4093044926 } }, + // Matrix for nskip = 3 * 8 ^ 20: + { { 897071837, 763331173, 3837362577 }, + { 294683328, 897071837, 2496877097 }, + { 2268904495, 294683328, 3496861697 } }, + // Matrix for nskip = 4 * 8 ^ 20: + { { 3035321857, 3971176093, 226779704 }, + { 3361614254, 3035321857, 2807125404 }, + { 326640887, 3361614254, 3147308542 } }, + // Matrix for nskip = 5 * 8 ^ 20: + { { 4010547095, 2725421511, 511986932 }, + { 1545732164, 4010547095, 2643845410 }, + { 2010134838, 1545732164, 3633977146 } }, + // Matrix for nskip = 6 * 8 ^ 20: + { { 3118026103, 1037137281, 1600236290 }, + { 2957620899, 3118026103, 433027378 }, + { 2926759199, 2957620899, 3989342054 } }, + // Matrix for nskip = 7 * 8 ^ 20: + { { 2423025801, 3089536821, 995021703 }, + { 3613148280, 2423025801, 241254395 }, + { 2857733472, 3613148280, 1868423350 } }, + // Matrix for nskip = 1 * 8 ^ 21: + { { 1774298149, 4179629947, 3145006948 }, + { 1688753503, 1774298149, 94869516 }, + { 2327946901, 1688753503, 2786835219 } }, + // Matrix for nskip = 2 * 8 ^ 21: + { { 185429251, 88142322, 3372328450 }, + { 1198432931, 185429251, 1527068783 }, + { 2880072915, 1198432931, 2782214191 } }, + // Matrix for nskip = 3 * 8 ^ 21: + { { 2610521617, 1116660734, 2002689706 }, + { 152508922, 2610521617, 2005955946 }, + { 3106947611, 152508922, 239569623 } }, + // Matrix for nskip = 4 * 8 ^ 21: + { { 127447080, 487724245, 2942566616 }, + { 2180042365, 127447080, 1722814040 }, + { 288658537, 2180042365, 4036691926 } }, + // Matrix for nskip = 5 * 8 ^ 21: + { { 3269833722, 2788004771, 1482042877 }, + { 834850082, 3269833722, 219243029 }, + { 3704080414, 834850082, 2784167151 } }, + // Matrix for nskip = 6 * 8 ^ 21: + { { 3956830949, 61587123, 1894752970 }, + { 1989171734, 3956830949, 3197042083 }, + { 457585003, 1989171734, 948838482 } }, + // Matrix for nskip = 7 * 8 ^ 21: + { { 1982687998, 3610851352, 1902386191 }, + { 2465097713, 1982687998, 1172472587 }, + { 1202471365, 2465097713, 3151246066 } }, + // Matrix for nskip = 1 * 8 ^ 22: + { { 1614979968, 1486547157, 1122661217 }, + { 3976346810, 1614979968, 2343603502 }, + { 3049605934, 3976346810, 440737492 } }, + // Matrix for nskip = 2 * 8 ^ 22: + { { 613698149, 3416334823, 3832821180 }, + { 1308958254, 613698149, 1338381534 }, + { 4058246217, 1308958254, 2070907998 } }, + // Matrix for nskip = 3 * 8 ^ 22: + { { 4069522778, 1558347771, 1555772973 }, + { 2924102885, 4069522778, 561176530 }, + { 566720713, 2924102885, 2660857604 } }, + // Matrix for nskip = 4 * 8 ^ 22: + { { 2575546527, 1033712257, 125034191 }, + { 2091411644, 2575546527, 226649669 }, + { 1198488263, 2091411644, 1522580506 } }, + // Matrix for nskip = 5 * 8 ^ 22: + { { 180639007, 1841709550, 234837148 }, + { 2219662691, 180639007, 4181748462 }, + { 3183232763, 2219662691, 2120135993 } }, + // Matrix for nskip = 6 * 8 ^ 22: + { { 4275704717, 2295071345, 1852983492 }, + { 3461773529, 4275704717, 417692359 }, + { 1477011348, 3461773529, 1587362209 } }, + // Matrix for nskip = 7 * 8 ^ 22: + { { 755069175, 2381439395, 890314398 }, + { 3019982523, 755069175, 572921618 }, + { 330076245, 3019982523, 2885887051 } }, + // Matrix for nskip = 1 * 8 ^ 23: + { { 1051614737, 227719572, 3725579556 }, + { 3910426444, 1051614737, 2075080920 }, + { 3357426062, 3910426444, 1473179318 } }, + // Matrix for nskip = 2 * 8 ^ 23: + { { 2999155498, 2971093563, 2685380188 }, + { 93938118, 2999155498, 4035265564 }, + { 3853931650, 93938118, 2034180250 } }, + // Matrix for nskip = 3 * 8 ^ 23: + { { 3543842569, 1469908890, 519769416 }, + { 3600765500, 3543842569, 1553393489 }, + { 60922281, 3600765500, 1226136476 } }, + // Matrix for nskip = 4 * 8 ^ 23: + { { 1253368368, 2860152458, 2836784419 }, + { 1656084047, 1253368368, 646811031 }, + { 3103367928, 1656084047, 3114448889 } }, + // Matrix for nskip = 5 * 8 ^ 23: + { { 2205916258, 1604698588, 3155610724 }, + { 2362004551, 2205916258, 181736283 }, + { 3847535541, 2362004551, 3814972479 } }, + // Matrix for nskip = 6 * 8 ^ 23: + { { 7725939, 1654580658, 4264117811 }, + { 1274240457, 7725939, 2108223515 }, + { 1813716775, 1274240457, 2141296207 } }, + // Matrix for nskip = 7 * 8 ^ 23: + { { 1828440339, 726307104, 566806600 }, + { 2069873554, 1828440339, 2003524657 }, + { 2528019064, 2069873554, 868624934 } }, + // Matrix for nskip = 1 * 8 ^ 24: + { { 2962469315, 4021086500, 2670244515 }, + { 299199825, 2962469315, 3624275162 }, + { 3634541206, 299199825, 1684552227 } }, + // Matrix for nskip = 2 * 8 ^ 24: + { { 804213223, 438999528, 3143925885 }, + { 1625976775, 804213223, 1494982903 }, + { 3498104358, 1625976775, 881729466 } }, + // Matrix for nskip = 3 * 8 ^ 24: + { { 2885386524, 2618720282, 4093772765 }, + { 1140571071, 2885386524, 2989367205 }, + { 2802821649, 1140571071, 742292537 } }, + // Matrix for nskip = 4 * 8 ^ 24: + { { 1547173514, 490999994, 918013965 }, + { 1312079237, 1547173514, 1905431135 }, + { 3784344293, 1312079237, 3643511238 } }, + // Matrix for nskip = 5 * 8 ^ 24: + { { 3363084915, 889964766, 2840623993 }, + { 485137636, 3363084915, 1563107974 }, + { 4117358359, 485137636, 2655518143 } }, + // Matrix for nskip = 6 * 8 ^ 24: + { { 2014523666, 1476325540, 1550754572 }, + { 588313388, 2014523666, 2691287218 }, + { 4248816946, 588313388, 1568942409 } }, + // Matrix for nskip = 7 * 8 ^ 24: + { { 2407332340, 3541076740, 1876171062 }, + { 1127328556, 2407332340, 3702106930 }, + { 1804600645, 1127328556, 2140373745 } }, + // Matrix for nskip = 1 * 8 ^ 25: + { { 3846994569, 2894966137, 1130633118 }, + { 4115190113, 3846994569, 777098754 }, + { 3088495692, 4115190113, 2193427908 } }, + // Matrix for nskip = 2 * 8 ^ 25: + { { 1511326704, 3759209742, 1610795712 }, + { 4292754251, 1511326704, 3889917532 }, + { 3859662829, 4292754251, 3708466080 } }, + // Matrix for nskip = 3 * 8 ^ 25: + { { 2721725192, 3847490931, 444351073 }, + { 429225403, 2721725192, 673508566 }, + { 387279730, 429225403, 3104869093 } }, + // Matrix for nskip = 4 * 8 ^ 25: + { { 972103006, 964807713, 878035866 }, + { 4248550197, 972103006, 1926628839 }, + { 1448629089, 4248550197, 3196114006 } }, + // Matrix for nskip = 5 * 8 ^ 25: + { { 549140019, 2935386277, 4206854109 }, + { 459549553, 549140019, 1011901572 }, + { 821145437, 459549553, 302470082 } }, + // Matrix for nskip = 6 * 8 ^ 25: + { { 907238901, 2926293232, 2865846472 }, + { 840689212, 907238901, 1249197731 }, + { 4278768404, 840689212, 3331097822 } }, + // Matrix for nskip = 7 * 8 ^ 25: + { { 105585154, 3513063153, 2552212444 }, + { 379969606, 105585154, 378686420 }, + { 3414457398, 379969606, 3084470277 } }, + // Matrix for nskip = 1 * 8 ^ 26: + { { 3497384788, 3174249442, 3182508868 }, + { 3864816447, 3497384788, 3038399593 }, + { 2546884738, 3864816447, 2980208068 } }, + // Matrix for nskip = 2 * 8 ^ 26: + { { 1776335558, 1189944887, 4095757548 }, + { 3813600746, 1776335558, 789475914 }, + { 4119698302, 3813600746, 2145357457 } }, + // Matrix for nskip = 3 * 8 ^ 26: + { { 1736653518, 945282763, 3568863651 }, + { 2539405616, 1736653518, 3870991887 }, + { 1676082014, 2539405616, 4282213129 } }, + // Matrix for nskip = 4 * 8 ^ 26: + { { 4022832294, 4130146837, 1942923647 }, + { 1675130777, 4022832294, 916677004 }, + { 4089786548, 1675130777, 116540512 } }, + // Matrix for nskip = 5 * 8 ^ 26: + { { 3414208535, 1938436883, 1996617380 }, + { 3508342845, 3414208535, 3024221061 }, + { 863275511, 3508342845, 3926625937 } }, + // Matrix for nskip = 6 * 8 ^ 26: + { { 943060309, 1550884686, 1524180490 }, + { 1603911046, 943060309, 659956132 }, + { 3864471824, 1603911046, 1981894197 } }, + // Matrix for nskip = 7 * 8 ^ 26: + { { 4039258344, 2877267458, 1263654722 }, + { 2264646264, 4039258344, 866786660 }, + { 3436002161, 2264646264, 1103279181 } }, + // Matrix for nskip = 1 * 8 ^ 27: + { { 165639584, 1205513289, 2037453462 }, + { 1444587280, 165639584, 161923120 }, + { 2617085459, 1444587280, 2006913311 } }, + // Matrix for nskip = 2 * 8 ^ 27: + { { 3458099202, 3062421748, 4052486999 }, + { 1064270720, 3458099202, 230768332 }, + { 4056228301, 1064270720, 2219267779 } }, + // Matrix for nskip = 3 * 8 ^ 27: + { { 4130534548, 3958841381, 2978123129 }, + { 3549040929, 4130534548, 624596665 }, + { 3007893075, 3549040929, 2033981581 } }, + // Matrix for nskip = 4 * 8 ^ 27: + { { 296275263, 3452455838, 2081462173 }, + { 1789143993, 296275263, 3463234943 }, + { 2097389984, 1789143993, 3447191459 } }, + // Matrix for nskip = 5 * 8 ^ 27: + { { 3690699991, 194807645, 3499022088 }, + { 895650639, 3690699991, 202155710 }, + { 3063493626, 895650639, 2818867049 } }, + // Matrix for nskip = 6 * 8 ^ 27: + { { 775854673, 2918396394, 2709062415 }, + { 2684216609, 775854673, 721391189 }, + { 4036938266, 2684216609, 1742271124 } }, + // Matrix for nskip = 7 * 8 ^ 27: + { { 3150458758, 4126093705, 1386916196 }, + { 3083923483, 3150458758, 2299677089 }, + { 1576871217, 3083923483, 1393814954 } }, + // Matrix for nskip = 1 * 8 ^ 28: + { { 2828288883, 3866690251, 410553827 }, + { 1587005542, 2828288883, 1469478670 }, + { 2766486018, 1587005542, 2627363449 } }, + // Matrix for nskip = 2 * 8 ^ 28: + { { 3288027530, 412403981, 2458742268 }, + { 4267121909, 3288027530, 138566505 }, + { 420803572, 4267121909, 4094554844 } }, + // Matrix for nskip = 3 * 8 ^ 28: + { { 2136361676, 3398888999, 2068559481 }, + { 3790597750, 2136361676, 3281478755 }, + { 4056706273, 3790597750, 1765993677 } }, + // Matrix for nskip = 4 * 8 ^ 28: + { { 3844599430, 2430152838, 3283485436 }, + { 2486244684, 3844599430, 4252427633 }, + { 3560842909, 2486244684, 3960267499 } }, + // Matrix for nskip = 5 * 8 ^ 28: + { { 3419145577, 107246070, 429885456 }, + { 1381214928, 3419145577, 1111366755 }, + { 767007913, 1381214928, 2270459619 } }, + // Matrix for nskip = 6 * 8 ^ 28: + { { 1494013447, 1485743041, 931794028 }, + { 3674972444, 1494013447, 2085831739 }, + { 62603161, 3674972444, 555083053 } }, + // Matrix for nskip = 7 * 8 ^ 28: + { { 1677686741, 1049056456, 3063490072 }, + { 3432517708, 1677686741, 1550912558 }, + { 3096606227, 3432517708, 349068991 } }, + // Matrix for nskip = 1 * 8 ^ 29: + { { 67933059, 1294996291, 2657888382 }, + { 513233413, 67933059, 1379805031 }, + { 44564058, 513233413, 86971645 } }, + // Matrix for nskip = 2 * 8 ^ 29: + { { 2732588524, 1866530072, 818237694 }, + { 2540507736, 2732588524, 3257104212 }, + { 1164400003, 2540507736, 1124501551 } }, + // Matrix for nskip = 3 * 8 ^ 29: + { { 1412660773, 1524580236, 2800129005 }, + { 3198153122, 1412660773, 3904718713 }, + { 2546401509, 3198153122, 386568104 } }, + // Matrix for nskip = 4 * 8 ^ 29: + { { 4199239222, 3155848463, 2121388468 }, + { 1135554501, 4199239222, 2056492193 }, + { 3251740389, 1135554501, 2343537248 } }, + // Matrix for nskip = 5 * 8 ^ 29: + { { 3239971958, 3891714065, 1807213249 }, + { 3694822198, 3239971958, 3557488352 }, + { 2750758637, 3694822198, 163867522 } }, + // Matrix for nskip = 6 * 8 ^ 29: + { { 884974087, 1753139982, 2087168228 }, + { 2226758301, 884974087, 1590955204 }, + { 1886560387, 2226758301, 4000127015 } }, + // Matrix for nskip = 7 * 8 ^ 29: + { { 3230269711, 3957529982, 3575750396 }, + { 3930348525, 3230269711, 2594598825 }, + { 3785901658, 3930348525, 4178374892 } }, + // Matrix for nskip = 1 * 8 ^ 30: + { { 550710036, 500329021, 1075236085 }, + { 356444753, 550710036, 1634965500 }, + { 58733535, 356444753, 1261552815 } }, + // Matrix for nskip = 2 * 8 ^ 30: + { { 708689546, 419139045, 2012018174 }, + { 706488081, 708689546, 1113760995 }, + { 585555005, 706488081, 76092226 } }, + // Matrix for nskip = 3 * 8 ^ 30: + { { 2584730290, 103417098, 2018833769 }, + { 831116151, 2584730290, 1919249397 }, + { 1036497162, 831116151, 2546254144 } }, + // Matrix for nskip = 4 * 8 ^ 30: + { { 1293182265, 3168473803, 366230236 }, + { 3319068849, 1293182265, 1085259665 }, + { 1675229290, 3319068849, 3912300371 } }, + // Matrix for nskip = 5 * 8 ^ 30: + { { 2602420349, 3992244735, 1543754813 }, + { 3770060220, 2602420349, 1407637442 }, + { 944746705, 3770060220, 2920440850 } }, + // Matrix for nskip = 6 * 8 ^ 30: + { { 1601703108, 619857159, 1219413461 }, + { 2824672719, 1601703108, 3707169777 }, + { 3352413650, 2824672719, 1098132331 } }, + // Matrix for nskip = 7 * 8 ^ 30: + { { 3630967154, 3444173778, 3289446159 }, + { 1769199423, 3630967154, 2021155330 }, + { 1478978985, 1769199423, 1976131087 } }, + // Matrix for nskip = 1 * 8 ^ 31: + { { 3186089068, 4188864734, 1211781402 }, + { 756122322, 3186089068, 578262892 }, + { 2518961174, 756122322, 1658665581 } }, + // Matrix for nskip = 2 * 8 ^ 31: + { { 1347291439, 2050427676, 736113023 }, + { 4102191254, 1347291439, 878627148 }, + { 1293500383, 4102191254, 745646810 } }, + // Matrix for nskip = 3 * 8 ^ 31: + { { 1428398286, 758558167, 59314928 }, + { 2615508955, 1428398286, 3061138405 }, + { 1098162878, 2615508955, 2401469211 } }, + // Matrix for nskip = 4 * 8 ^ 31: + { { 4196897331, 3436564969, 1900167098 }, + { 3108887846, 4196897331, 2697923227 }, + { 1405263476, 3108887846, 314631094 } }, + // Matrix for nskip = 5 * 8 ^ 31: + { { 3004743607, 2733058282, 4202297421 }, + { 956778663, 3004743607, 1815192601 }, + { 2211295748, 956778663, 3626831178 } }, + // Matrix for nskip = 6 * 8 ^ 31: + { { 3694919563, 2520419703, 731922800 }, + { 540077867, 3694919563, 2433069844 }, + { 2129238146, 540077867, 301939378 } }, + // Matrix for nskip = 7 * 8 ^ 31: + { { 2475140271, 37335008, 2778457406 }, + { 2217587145, 2475140271, 1363889163 }, + { 135344313, 2217587145, 1707617706 } }, + // Matrix for nskip = 1 * 8 ^ 32: + { { 958383622, 3694638688, 1150087061 }, + { 3770009830, 958383622, 793326651 }, + { 533700213, 3770009830, 1513734026 } }, + // Matrix for nskip = 2 * 8 ^ 32: + { { 4119603367, 3479396923, 3534176399 }, + { 3765397477, 4119603367, 1458031003 }, + { 3380901602, 3765397477, 2684083587 } }, + // Matrix for nskip = 3 * 8 ^ 32: + { { 178016378, 1184002529, 789650986 }, + { 389885259, 178016378, 3729279189 }, + { 1268575347, 389885259, 4091367000 } }, + // Matrix for nskip = 4 * 8 ^ 32: + { { 980937351, 2094378936, 448446028 }, + { 1421333909, 980937351, 3405683645 }, + { 323724368, 1421333909, 338680738 } }, + // Matrix for nskip = 5 * 8 ^ 32: + { { 2381808660, 341372255, 146194193 }, + { 4185254045, 2381808660, 1244677534 }, + { 2006223188, 4185254045, 3589653882 } }, + // Matrix for nskip = 6 * 8 ^ 32: + { { 1104593159, 2457034166, 4243190272 }, + { 2690000574, 1104593159, 3592133108 }, + { 3935039161, 2690000574, 2028886430 } }, + // Matrix for nskip = 7 * 8 ^ 32: + { { 798595991, 3072704016, 1453032677 }, + { 3595149031, 798595991, 1556294726 }, + { 775957906, 3595149031, 208124234 } }, + // Matrix for nskip = 1 * 8 ^ 33: + { { 2942968846, 4293637338, 3549906544 }, + { 527851489, 2942968846, 3852871282 }, + { 4209198933, 527851489, 1091268872 } }, + // Matrix for nskip = 2 * 8 ^ 33: + { { 1975983015, 2092556693, 611187071 }, + { 3982652344, 1975983015, 3001736262 }, + { 2055073597, 3982652344, 1875181995 } }, + // Matrix for nskip = 3 * 8 ^ 33: + { { 1752967931, 1167063522, 3817182484 }, + { 3760899628, 1752967931, 2808655727 }, + { 3110603267, 3760899628, 1832178008 } }, + // Matrix for nskip = 4 * 8 ^ 33: + { { 2970221269, 880904779, 2447465272 }, + { 2888742196, 2970221269, 3521651749 }, + { 3019977656, 2888742196, 2712717326 } }, + // Matrix for nskip = 5 * 8 ^ 33: + { { 604958655, 442191761, 1996070625 }, + { 1269454015, 604958655, 814754560 }, + { 507433046, 1269454015, 2488458391 } }, + // Matrix for nskip = 6 * 8 ^ 33: + { { 710612185, 99734716, 3956229929 }, + { 2137129319, 710612185, 2895847378 }, + { 1727032860, 2137129319, 1001260701 } }, + // Matrix for nskip = 7 * 8 ^ 33: + { { 1066664047, 4152765348, 1734907969 }, + { 2968154336, 1066664047, 2381691001 }, + { 1497199245, 2968154336, 3563839605 } }, + // Matrix for nskip = 1 * 8 ^ 34: + { { 419134859, 2976059897, 747864206 }, + { 4101695717, 419134859, 4264593116 }, + { 2657991148, 4101695717, 2542621682 } }, + // Matrix for nskip = 2 * 8 ^ 34: + { { 4043135299, 1612983166, 1149778656 }, + { 1267010518, 4043135299, 3496325546 }, + { 3094232897, 1267010518, 2949176293 } }, + // Matrix for nskip = 3 * 8 ^ 34: + { { 3214297332, 2846434362, 4106231685 }, + { 1780972559, 3214297332, 1132838092 }, + { 1348023856, 1780972559, 537227984 } }, + // Matrix for nskip = 4 * 8 ^ 34: + { { 3949395794, 1774568686, 2123036003 }, + { 2182983404, 3949395794, 2355671350 }, + { 2820933455, 2182983404, 513963325 } }, + // Matrix for nskip = 5 * 8 ^ 34: + { { 1877604589, 3803366824, 2927718923 }, + { 2817972608, 1877604589, 901177092 }, + { 1008515195, 2817972608, 1900906578 } }, + // Matrix for nskip = 6 * 8 ^ 34: + { { 2247365780, 1508191753, 929996525 }, + { 2014701429, 2247365780, 2906849518 }, + { 1864911773, 2014701429, 634217040 } }, + // Matrix for nskip = 7 * 8 ^ 34: + { { 3200692723, 3246632578, 3558417384 }, + { 733273917, 3200692723, 715293224 }, + { 3878803573, 733273917, 3720987401 } }, + // Matrix for nskip = 1 * 8 ^ 35: + { { 3046911698, 2576744453, 2492729814 }, + { 4277866093, 3046911698, 3146977604 }, + { 2249371766, 4277866093, 3622293976 } }, + // Matrix for nskip = 2 * 8 ^ 35: + { { 1391529818, 423458502, 2587125255 }, + { 3536237833, 1391529818, 985347517 }, + { 157623850, 3536237833, 1015566287 } }, + // Matrix for nskip = 3 * 8 ^ 35: + { { 2768170623, 2671124421, 1038000683 }, + { 2258964805, 2768170623, 3036723158 }, + { 2454977948, 2258964805, 2502325941 } }, + // Matrix for nskip = 4 * 8 ^ 35: + { { 48329260, 2599277669, 821961664 }, + { 902187690, 48329260, 1716556555 }, + { 4019658974, 902187690, 950730510 } }, + // Matrix for nskip = 5 * 8 ^ 35: + { { 3100975771, 1019061132, 1844417430 }, + { 1634016885, 3100975771, 2161076681 }, + { 378757639, 1634016885, 4124897232 } }, + // Matrix for nskip = 6 * 8 ^ 35: + { { 1045387495, 796030826, 1236131839 }, + { 2328291482, 1045387495, 2884310858 }, + { 3863948457, 2328291482, 465921502 } }, + // Matrix for nskip = 7 * 8 ^ 35: + { { 3483511399, 741205873, 1920164372 }, + { 1105604243, 3483511399, 2420741811 }, + { 2484220821, 1105604243, 2513215163 } }, + // Matrix for nskip = 1 * 8 ^ 36: + { { 1318489562, 1530977112, 3713577419 }, + { 4270158447, 1318489562, 1654940598 }, + { 2679964938, 4270158447, 1337075195 } }, + // Matrix for nskip = 2 * 8 ^ 36: + { { 770600793, 3249576224, 3578552768 }, + { 2710443459, 770600793, 2990852339 }, + { 3098163705, 2710443459, 522138188 } }, + // Matrix for nskip = 3 * 8 ^ 36: + { { 3299888517, 1806316064, 2474407987 }, + { 3432253975, 3299888517, 3480703284 }, + { 201692417, 3432253975, 1711417284 } }, + // Matrix for nskip = 4 * 8 ^ 36: + { { 2803285489, 1922250286, 3164022812 }, + { 477609731, 2803285489, 2140252218 }, + { 2252852611, 477609731, 3058519788 } }, + // Matrix for nskip = 5 * 8 ^ 36: + { { 3735324161, 860809210, 2792496593 }, + { 1613420642, 3735324161, 651730634 }, + { 3412387271, 1613420642, 2796594703 } }, + // Matrix for nskip = 6 * 8 ^ 36: + { { 993539593, 3499265007, 3772074010 }, + { 3213913829, 993539593, 3655831787 }, + { 2561980091, 3213913829, 2164990937 } }, + // Matrix for nskip = 7 * 8 ^ 36: + { { 76754721, 818311023, 1258273773 }, + { 2914546594, 76754721, 3007787703 }, + { 1554324281, 2914546594, 1645121444 } }, + // Matrix for nskip = 1 * 8 ^ 37: + { { 208329741, 3633562083, 3548346666 }, + { 3892091460, 208329741, 516833304 }, + { 3440632377, 3892091460, 1638833719 } }, + // Matrix for nskip = 2 * 8 ^ 37: + { { 1816075033, 3570111203, 959489356 }, + { 3482051486, 1816075033, 861657108 }, + { 3119495098, 3482051486, 2576849579 } }, + // Matrix for nskip = 3 * 8 ^ 37: + { { 955576990, 607798602, 220457899 }, + { 760121425, 955576990, 1155400464 }, + { 1209136348, 760121425, 1165671753 } }, + // Matrix for nskip = 4 * 8 ^ 37: + { { 4240216888, 2891584407, 2102314945 }, + { 4064489450, 4240216888, 1427441010 }, + { 2441164913, 4064489450, 3558527186 } }, + // Matrix for nskip = 5 * 8 ^ 37: + { { 3943073787, 2113696223, 3840029496 }, + { 42559030, 3943073787, 2203932271 }, + { 638717597, 42559030, 3208053933 } }, + // Matrix for nskip = 6 * 8 ^ 37: + { { 714331518, 510361535, 3438751245 }, + { 2783614947, 714331518, 666348656 }, + { 4028058908, 2783614947, 2994150339 } }, + // Matrix for nskip = 7 * 8 ^ 37: + { { 3978295779, 1441779930, 4249164235 }, + { 1006134725, 3978295779, 2022224066 }, + { 1257228544, 1006134725, 3563676111 } }, + // Matrix for nskip = 1 * 8 ^ 38: + { { 2918371295, 65155283, 3469357011 }, + { 3579773554, 2918371295, 3494391959 }, + { 3266584309, 3579773554, 3837485479 } }, + // Matrix for nskip = 2 * 8 ^ 38: + { { 2959420453, 1365016881, 4082486022 }, + { 236489012, 2959420453, 3802558529 }, + { 2687043642, 236489012, 2547086826 } }, + // Matrix for nskip = 3 * 8 ^ 38: + { { 3501988208, 1843500325, 3464182128 }, + { 969269805, 3501988208, 2232088910 }, + { 3829792024, 969269805, 2334756085 } }, + // Matrix for nskip = 4 * 8 ^ 38: + { { 4185325422, 2762854843, 3200044912 }, + { 3664909559, 4185325422, 3543921700 }, + { 4240262918, 3664909559, 2853212443 } }, + // Matrix for nskip = 5 * 8 ^ 38: + { { 3870531367, 2625370600, 1928035826 }, + { 1477778653, 3870531367, 4167218005 }, + { 2810379745, 1477778653, 1547435981 } }, + // Matrix for nskip = 6 * 8 ^ 38: + { { 2166942438, 2045317959, 2862960125 }, + { 1192305592, 2166942438, 2202186359 }, + { 1282445014, 1192305592, 3680855685 } }, + // Matrix for nskip = 7 * 8 ^ 38: + { { 4183888729, 1630438655, 1622555680 }, + { 841523235, 4183888729, 266662726 }, + { 1888300357, 841523235, 553070804 } }, + // Matrix for nskip = 1 * 8 ^ 39: + { { 2618500928, 4237264351, 1470046497 }, + { 1893990098, 2618500928, 2982567031 }, + { 3017062825, 1893990098, 3195556801 } }, + // Matrix for nskip = 2 * 8 ^ 39: + { { 1868464655, 3407681142, 1652841784 }, + { 1678569574, 1868464655, 4162480901 }, + { 1477016185, 1678569574, 4145063890 } }, + // Matrix for nskip = 3 * 8 ^ 39: + { { 346858981, 2885211332, 1550050752 }, + { 3168708136, 346858981, 2121517268 }, + { 696413464, 3168708136, 2779761666 } }, + // Matrix for nskip = 4 * 8 ^ 39: + { { 792188465, 4251338402, 2219407026 }, + { 3840340879, 792188465, 3493367465 }, + { 2979958414, 3840340879, 2338974139 } }, + // Matrix for nskip = 5 * 8 ^ 39: + { { 3859433262, 3764728773, 1297631730 }, + { 3833824001, 3859433262, 1333287789 }, + { 1909447704, 3833824001, 2135933046 } }, + // Matrix for nskip = 6 * 8 ^ 39: + { { 102264893, 4038432252, 2717349223 }, + { 709433989, 102264893, 1807326569 }, + { 2997676666, 709433989, 3722753261 } }, + // Matrix for nskip = 7 * 8 ^ 39: + { { 4020257258, 1217293203, 2346103599 }, + { 3809824315, 4020257258, 576285090 }, + { 3162683019, 3809824315, 2652264596 } }, + // Matrix for nskip = 1 * 8 ^ 40: + { { 478845700, 2378167062, 882114621 }, + { 1674533845, 478845700, 3572905305 }, + { 3571222880, 1674533845, 1242316901 } }, + // Matrix for nskip = 2 * 8 ^ 40: + { { 2636090868, 1972761498, 71690719 }, + { 1228103463, 2636090868, 1280685025 }, + { 3741735502, 1228103463, 994061750 } }, + // Matrix for nskip = 3 * 8 ^ 40: + { { 2765592972, 3759047976, 2089192298 }, + { 2592791249, 2765592972, 2079317731 }, + { 3195761319, 2592791249, 913428082 } }, + // Matrix for nskip = 4 * 8 ^ 40: + { { 1156725261, 1100755307, 221922891 }, + { 2892200461, 1156725261, 1505716533 }, + { 2287613563, 2892200461, 3689457190 } }, + // Matrix for nskip = 5 * 8 ^ 40: + { { 716602832, 851112058, 2726490354 }, + { 328778061, 716602832, 2662750501 }, + { 2300190858, 328778061, 2031908929 } }, + // Matrix for nskip = 6 * 8 ^ 40: + { { 131535614, 3548535605, 1837882588 }, + { 3257415168, 131535614, 1374937136 }, + { 1879184234, 3257415168, 167534374 } }, + // Matrix for nskip = 7 * 8 ^ 40: + { { 3131954528, 4223897546, 515553914 }, + { 326215900, 3131954528, 644217952 }, + { 934922655, 326215900, 2645770575 } }, + // Matrix for nskip = 1 * 8 ^ 41: + { { 1387244644, 3135090808, 1243609165 }, + { 1724967466, 1387244644, 3296353235 }, + { 1064364031, 1724967466, 2107521044 } }, + // Matrix for nskip = 2 * 8 ^ 41: + { { 2822471992, 2034317853, 2071407475 }, + { 170903528, 2822471992, 1322162887 }, + { 2524982332, 170903528, 2656231333 } }, + // Matrix for nskip = 3 * 8 ^ 41: + { { 2401421275, 3219909065, 1167519964 }, + { 3200856372, 2401421275, 2651362201 }, + { 3150793696, 3200856372, 3740263529 } }, + // Matrix for nskip = 4 * 8 ^ 41: + { { 3653936868, 3893194049, 2484299328 }, + { 1313746234, 3653936868, 1705346273 }, + { 1397638018, 1313746234, 4015529545 } }, + // Matrix for nskip = 5 * 8 ^ 41: + { { 762850190, 2502708647, 3030789377 }, + { 605169915, 762850190, 2517301940 }, + { 2651641442, 605169915, 3739297479 } }, + // Matrix for nskip = 6 * 8 ^ 41: + { { 4185157227, 3109351418, 2907095532 }, + { 3981440524, 4185157227, 2447807956 }, + { 1358765607, 3981440524, 2947483756 } }, + // Matrix for nskip = 7 * 8 ^ 41: + { { 616351240, 2708761949, 3510102453 }, + { 1192816102, 616351240, 3430261471 }, + { 3769975746, 1192816102, 1092752722 } }, + // Matrix for nskip = 1 * 8 ^ 42: + { { 4129760842, 1671665759, 1677834656 }, + { 3200005334, 4129760842, 3486207172 }, + { 2850728736, 3200005334, 3076201597 } }, + // Matrix for nskip = 2 * 8 ^ 42: + { { 1464411153, 277697599, 1610723613 }, + { 32183930, 1464411153, 1022607788 }, + { 2824425944, 32183930, 2093834863 } }, + // Matrix for nskip = 3 * 8 ^ 42: + { { 4289888328, 3225021158, 546274137 }, + { 3161813725, 4289888328, 3178255601 }, + { 811227116, 3161813725, 2040329321 } }, + // Matrix for nskip = 4 * 8 ^ 42: + { { 3492361727, 1027004383, 3167429889 }, + { 3674905362, 3492361727, 3572939265 }, + { 4270409313, 3674905362, 698814233 } }, + // Matrix for nskip = 5 * 8 ^ 42: + { { 1024068271, 2798745077, 2659447825 }, + { 2040144100, 1024068271, 1035060877 }, + { 2866843005, 2040144100, 787687659 } }, + // Matrix for nskip = 6 * 8 ^ 42: + { { 2906151318, 3986151835, 2581649800 }, + { 571744464, 2906151318, 1834943086 }, + { 3448634312, 571744464, 290967548 } }, + // Matrix for nskip = 7 * 8 ^ 42: + { { 1570041711, 1880130578, 2514738078 }, + { 3388141786, 1570041711, 744775425 }, + { 2735736928, 3388141786, 964597855 } }, + // Matrix for nskip = 1 * 8 ^ 43: + { { 880482061, 205175925, 4070445105 }, + { 2208329119, 880482061, 1933248566 }, + { 3741227945, 2208329119, 3962062826 } }, + // Matrix for nskip = 2 * 8 ^ 43: + { { 4184605179, 1189429800, 567967482 }, + { 107217966, 4184605179, 784865788 }, + { 549462420, 107217966, 3134382704 } }, + // Matrix for nskip = 3 * 8 ^ 43: + { { 1386364785, 4079260578, 3001857777 }, + { 3010784539, 1386364785, 3667065093 }, + { 3692171012, 3010784539, 2361530061 } }, + // Matrix for nskip = 4 * 8 ^ 43: + { { 2732536445, 1231107067, 3374588386 }, + { 409954030, 2732536445, 1044831206 }, + { 3398162498, 409954030, 3505648581 } }, + // Matrix for nskip = 5 * 8 ^ 43: + { { 3249719425, 4215633308, 1637240461 }, + { 151877124, 3249719425, 2638755179 }, + { 3634975465, 151877124, 1546467979 } }, + // Matrix for nskip = 6 * 8 ^ 43: + { { 2408251701, 89238831, 4165007723 }, + { 4262743528, 2408251701, 4114669800 }, + { 2878757823, 4262743528, 3182943863 } }, + // Matrix for nskip = 7 * 8 ^ 43: + { { 1831049905, 2380192587, 325575207 }, + { 2045407448, 1831049905, 3463310486 }, + { 1637651789, 2045407448, 1889914987 } }, + // Matrix for nskip = 1 * 8 ^ 44: + { { 2169560691, 1076348534, 637306236 }, + { 3704346564, 2169560691, 293694496 }, + { 632453145, 3704346564, 1609425246 } }, + // Matrix for nskip = 2 * 8 ^ 44: + { { 372115891, 3928812480, 2830541169 }, + { 3056527841, 372115891, 1924239834 }, + { 3044937468, 3056527841, 547142630 } }, + // Matrix for nskip = 3 * 8 ^ 44: + { { 3652440052, 1383186997, 3140353867 }, + { 1157890357, 3652440052, 3280219833 }, + { 2953685245, 1157890357, 481162011 } }, + // Matrix for nskip = 4 * 8 ^ 44: + { { 1660852083, 3635660815, 1389092450 }, + { 1025573319, 1660852083, 3276803366 }, + { 4036331438, 1025573319, 4092197741 } }, + // Matrix for nskip = 5 * 8 ^ 44: + { { 2683005143, 1323793242, 1291869629 }, + { 2903240813, 2683005143, 3854329533 }, + { 2695585089, 2903240813, 1426976484 } }, + // Matrix for nskip = 6 * 8 ^ 44: + { { 56767734, 116994667, 111909274 }, + { 3730950473, 56767734, 2191610434 }, + { 1091419714, 3730950473, 718571338 } }, + // Matrix for nskip = 7 * 8 ^ 44: + { { 336318787, 391538001, 10025372 }, + { 3157633492, 336318787, 2821500332 }, + { 3413552779, 3157633492, 4255875513 } }, + // Matrix for nskip = 1 * 8 ^ 45: + { { 1360732901, 2887812973, 4101068693 }, + { 52572783, 1360732901, 112458461 }, + { 2636566855, 52572783, 1136777988 } }, + // Matrix for nskip = 2 * 8 ^ 45: + { { 3455696508, 536919193, 3978804036 }, + { 3094157668, 3455696508, 3821833900 }, + { 2278849016, 3094157668, 2531965909 } }, + // Matrix for nskip = 3 * 8 ^ 45: + { { 105839550, 1126024816, 287198647 }, + { 351807867, 105839550, 643672297 }, + { 1483330368, 351807867, 3781751861 } }, + // Matrix for nskip = 4 * 8 ^ 45: + { { 2125991744, 890897326, 3790557569 }, + { 1433592392, 2125991744, 3671109604 }, + { 808215503, 1433592392, 2446306581 } }, + // Matrix for nskip = 5 * 8 ^ 45: + { { 3640380877, 422210679, 1510633752 }, + { 1569172639, 3640380877, 3192250064 }, + { 1376060847, 1569172639, 2027936709 } }, + // Matrix for nskip = 6 * 8 ^ 45: + { { 3177388361, 1344488735, 2994552097 }, + { 284988983, 3177388361, 3227966904 }, + { 2044803401, 284988983, 4277058832 } }, + // Matrix for nskip = 7 * 8 ^ 45: + { { 3412413108, 4186230758, 3922996456 }, + { 3683836901, 3412413108, 271458827 }, + { 3964969101, 3683836901, 539759068 } }, + // Matrix for nskip = 1 * 8 ^ 46: + { { 3524411799, 932865240, 1838275365 }, + { 1789634890, 3524411799, 4130736474 }, + { 2252266098, 1789634890, 3048775967 } }, + // Matrix for nskip = 2 * 8 ^ 46: + { { 1773339925, 948403862, 1999624391 }, + { 983864203, 1773339925, 3734776305 }, + { 314407045, 983864203, 2648614071 } }, + // Matrix for nskip = 3 * 8 ^ 46: + { { 1928167136, 2078532030, 1690025039 }, + { 2529043017, 1928167136, 1858653225 }, + { 2142588179, 2529043017, 2188623418 } }, + // Matrix for nskip = 4 * 8 ^ 46: + { { 321802921, 1099164995, 2112167358 }, + { 3760936985, 321802921, 1003573324 }, + { 3758858458, 3760936985, 4014658840 } }, + // Matrix for nskip = 5 * 8 ^ 46: + { { 774593807, 1711411238, 3653945922 }, + { 1751249890, 774593807, 10024535 }, + { 9872042, 1751249890, 2762944894 } }, + // Matrix for nskip = 6 * 8 ^ 46: + { { 2825735696, 1396615016, 3702967335 }, + { 3652693925, 2825735696, 4120492766 }, + { 1992385943, 3652693925, 686943862 } }, + // Matrix for nskip = 7 * 8 ^ 46: + { { 2314946087, 4102352240, 989909889 }, + { 459855750, 2314946087, 1424771850 }, + { 1469834717, 459855750, 2094187769 } }, + // Matrix for nskip = 1 * 8 ^ 47: + { { 2196438580, 805386227, 4266375092 }, + { 4124675351, 2196438580, 2527961345 }, + { 94452540, 4124675351, 2825656399 } }, + // Matrix for nskip = 2 * 8 ^ 47: + { { 66735368, 2228005807, 4186703168 }, + { 2624855312, 66735368, 2708679078 }, + { 4098470056, 2624855312, 1773862183 } }, + // Matrix for nskip = 3 * 8 ^ 47: + { { 320933009, 1915174474, 3744070526 }, + { 562558814, 320933009, 1706424966 }, + { 413766233, 562558814, 2881230326 } }, + // Matrix for nskip = 4 * 8 ^ 47: + { { 3072642883, 2746897053, 2690305546 }, + { 1105106652, 3072642883, 4047666135 }, + { 2862886282, 1105106652, 3597347398 } }, + // Matrix for nskip = 5 * 8 ^ 47: + { { 1498353481, 3428325510, 1424606567 }, + { 372840925, 1498353481, 1901161856 }, + { 1201903815, 372840925, 1622747589 } }, + // Matrix for nskip = 6 * 8 ^ 47: + { { 3754310983, 2829438112, 3947637114 }, + { 2617184648, 3754310983, 3119630359 }, + { 2102395010, 2617184648, 2313448358 } }, + // Matrix for nskip = 7 * 8 ^ 47: + { { 2033651727, 3918276995, 2324222273 }, + { 2517499860, 2033651727, 3237758154 }, + { 3966641526, 2517499860, 2296152269 } }, + // Matrix for nskip = 1 * 8 ^ 48: + { { 232906611, 3873338256, 4051554873 }, + { 3027413363, 232906611, 3159432673 }, + { 3872967050, 3027413363, 987156327 } }, + // Matrix for nskip = 2 * 8 ^ 48: + { { 1160686753, 3676603152, 1635979789 }, + { 1447386846, 1160686753, 2670438424 }, + { 816212890, 1447386846, 4288868534 } }, + // Matrix for nskip = 3 * 8 ^ 48: + { { 232406022, 1438391315, 351811028 }, + { 792615675, 232406022, 2249558877 }, + { 4000461186, 792615675, 3773572468 } }, + // Matrix for nskip = 4 * 8 ^ 48: + { { 3825238244, 1445162354, 2362389441 }, + { 3440193648, 3825238244, 3520937545 }, + { 2652790808, 3440193648, 405299994 } }, + // Matrix for nskip = 5 * 8 ^ 48: + { { 1153297111, 1584881761, 3755481813 }, + { 2565782177, 1153297111, 595979811 }, + { 3520546605, 2565782177, 1485833084 } }, + // Matrix for nskip = 6 * 8 ^ 48: + { { 2264796250, 1995295374, 4156333842 }, + { 4182411213, 2264796250, 3692855966 }, + { 2398102705, 4182411213, 135106935 } }, + // Matrix for nskip = 7 * 8 ^ 48: + { { 1510709042, 3654924984, 4137143940 }, + { 3411234559, 1510709042, 3713963703 }, + { 3111723660, 3411234559, 3580357515 } }, + // Matrix for nskip = 1 * 8 ^ 49: + { { 1984094858, 532165989, 2027397575 }, + { 1455977136, 1984094858, 2433255524 }, + { 1039994763, 1455977136, 2069333087 } }, + // Matrix for nskip = 2 * 8 ^ 49: + { { 3680843319, 2332949611, 3516795313 }, + { 2033851810, 3680843319, 3843367307 }, + { 3686294589, 2033851810, 3912995069 } }, + // Matrix for nskip = 3 * 8 ^ 49: + { { 2570307024, 165497191, 3880130435 }, + { 540713030, 2570307024, 1096034689 }, + { 3859799631, 540713030, 3714945286 } }, + // Matrix for nskip = 4 * 8 ^ 49: + { { 967423689, 1724183394, 635932799 }, + { 641380480, 967423689, 2145297779 }, + { 1723000412, 641380480, 455633660 } }, + // Matrix for nskip = 5 * 8 ^ 49: + { { 2807559499, 2180128950, 1968769828 }, + { 1885526032, 2807559499, 3568246807 }, + { 1874951461, 1885526032, 2399805320 } }, + // Matrix for nskip = 6 * 8 ^ 49: + { { 743327961, 3817146458, 2078921540 }, + { 752843557, 743327961, 3382133383 }, + { 1546279541, 752843557, 4269455046 } }, + // Matrix for nskip = 7 * 8 ^ 49: + { { 355924266, 3865252236, 3092467664 }, + { 2414940441, 355924266, 3290161562 }, + { 493050060, 2414940441, 2727946913 } }, + // Matrix for nskip = 1 * 8 ^ 50: + { { 2130938335, 1534972306, 2511584766 }, + { 273828453, 2130938335, 3112810093 }, + { 4084843716, 273828453, 1399334152 } }, + // Matrix for nskip = 2 * 8 ^ 50: + { { 168278549, 541167592, 190177712 }, + { 403188859, 168278549, 2092073970 }, + { 58789558, 403188859, 2777887189 } }, + // Matrix for nskip = 3 * 8 ^ 50: + { { 664028138, 360061317, 3240810721 }, + { 3427777045, 664028138, 589375738 }, + { 1247469758, 3427777045, 4103288151 } }, + // Matrix for nskip = 4 * 8 ^ 50: + { { 634843389, 4082275720, 2092828966 }, + { 351187677, 634843389, 1312056270 }, + { 3347241070, 351187677, 2417192332 } }, + // Matrix for nskip = 5 * 8 ^ 50: + { { 3269976890, 3103127568, 907107523 }, + { 3154851935, 3269976890, 1078491382 }, + { 1129461097, 3154851935, 3960596933 } }, + // Matrix for nskip = 6 * 8 ^ 50: + { { 1155790154, 89494164, 1039763155 }, + { 393005763, 1155790154, 2648470077 }, + { 2830413843, 393005763, 1280581785 } }, + // Matrix for nskip = 7 * 8 ^ 50: + { { 2340682307, 3775335435, 3604492026 }, + { 4198859651, 2340682307, 1392463605 }, + { 1917833692, 4198859651, 2536657316 } }, + // Matrix for nskip = 1 * 8 ^ 51: + { { 443276110, 1113643788, 271102234 }, + { 3083745876, 443276110, 3370743767 }, + { 4200577503, 3083745876, 3298601960 } }, + // Matrix for nskip = 2 * 8 ^ 51: + { { 3533393557, 764977733, 3400275098 }, + { 144639933, 3533393557, 2646475951 }, + { 77963866, 144639933, 3794766611 } }, + // Matrix for nskip = 3 * 8 ^ 51: + { { 914011908, 1379977154, 3635095314 }, + { 4096393357, 914011908, 962932343 }, + { 410940557, 4096393357, 2300259911 } }, + // Matrix for nskip = 4 * 8 ^ 51: + { { 4064854722, 1198665008, 2872196602 }, + { 3274748603, 4064854722, 4164637970 }, + { 4238693771, 3274748603, 1981721347 } }, + // Matrix for nskip = 5 * 8 ^ 51: + { { 658075764, 868441731, 631337149 }, + { 3000164892, 658075764, 3213078611 }, + { 2494369285, 3000164892, 1969086166 } }, + // Matrix for nskip = 6 * 8 ^ 51: + { { 1202027740, 1218291611, 251455117 }, + { 1904530179, 1202027740, 1121637945 }, + { 2014861157, 1904530179, 3331497439 } }, + // Matrix for nskip = 7 * 8 ^ 51: + { { 860183345, 3722900937, 2577917907 }, + { 184407828, 860183345, 3959662009 }, + { 1130199284, 184407828, 1996334021 } }, + // Matrix for nskip = 1 * 8 ^ 52: + { { 2279220396, 2355957139, 1417574285 }, + { 885864931, 2279220396, 1344421653 }, + { 1895527787, 885864931, 3726919367 } }, + // Matrix for nskip = 2 * 8 ^ 52: + { { 2898100178, 2427331008, 348923199 }, + { 3175444953, 2898100178, 4290541487 }, + { 246118669, 3175444953, 3410622769 } }, + // Matrix for nskip = 3 * 8 ^ 52: + { { 55373162, 3987120186, 2739617092 }, + { 488341106, 55373162, 3877861726 }, + { 468535899, 488341106, 2277317349 } }, + // Matrix for nskip = 4 * 8 ^ 52: + { { 284442065, 4064194676, 2295560707 }, + { 4182706556, 284442065, 3696899246 }, + { 1201342255, 4182706556, 1145356382 } }, + // Matrix for nskip = 5 * 8 ^ 52: + { { 854963956, 3894612396, 2185360428 }, + { 3161673906, 854963956, 1200638109 }, + { 808492591, 3161673906, 1983142708 } }, + // Matrix for nskip = 6 * 8 ^ 52: + { { 2146747531, 896368240, 1430380976 }, + { 1613992473, 2146747531, 901075807 }, + { 2390399884, 1613992473, 270201416 } }, + // Matrix for nskip = 7 * 8 ^ 52: + { { 1033390767, 4214343810, 3176316290 }, + { 238941078, 1033390767, 957806905 }, + { 3045719234, 238941078, 3992043804 } }, + // Matrix for nskip = 1 * 8 ^ 53: + { { 656615546, 442908965, 3724738272 }, + { 1624967553, 656615546, 798014134 }, + { 1157949454, 1624967553, 496247378 } }, + // Matrix for nskip = 2 * 8 ^ 53: + { { 265689579, 675056541, 3009083380 }, + { 3820679930, 265689579, 2961990151 }, + { 562287964, 3820679930, 1853486796 } }, + // Matrix for nskip = 3 * 8 ^ 53: + { { 3115797761, 1090045712, 399035362 }, + { 452658959, 3115797761, 3053809839 }, + { 3970000518, 452658959, 2899502994 } }, + // Matrix for nskip = 4 * 8 ^ 53: + { { 1675739167, 2319843005, 760605578 }, + { 4161492847, 1675739167, 226142150 }, + { 1017447188, 4161492847, 3431158427 } }, + // Matrix for nskip = 5 * 8 ^ 53: + { { 1814415714, 3446998641, 1659100687 }, + { 299018378, 1814415714, 3661851369 }, + { 2777381296, 299018378, 730677422 } }, + // Matrix for nskip = 6 * 8 ^ 53: + { { 497640593, 3005114205, 2309875696 }, + { 3522463659, 497640593, 590519806 }, + { 855175401, 3522463659, 1973739759 } }, + // Matrix for nskip = 7 * 8 ^ 53: + { { 2668363194, 344864589, 270881279 }, + { 981182918, 2668363194, 1986955069 }, + { 956851812, 981182918, 3901969881 } }, + // Matrix for nskip = 1 * 8 ^ 54: + { { 1759873736, 2334568602, 2154570180 }, + { 1812793060, 1759873736, 2111094408 }, + { 1168460586, 1812793060, 2495653141 } }, + // Matrix for nskip = 2 * 8 ^ 54: + { { 317621194, 868104288, 664971082 }, + { 2340275074, 317621194, 2168960688 }, + { 725706104, 2340275074, 3532023115 } }, + // Matrix for nskip = 3 * 8 ^ 54: + { { 3585587043, 2378713321, 2463381051 }, + { 2919944362, 3585587043, 1464119531 }, + { 3588451359, 2919944362, 1912059035 } }, + // Matrix for nskip = 4 * 8 ^ 54: + { { 3926931954, 2907684453, 615601328 }, + { 1132340715, 3926931954, 676995757 }, + { 1154819290, 1132340715, 1662727700 } }, + // Matrix for nskip = 5 * 8 ^ 54: + { { 918221359, 2912639129, 1883551759 }, + { 4114315731, 918221359, 1703365082 }, + { 2391341541, 4114315731, 3946112236 } }, + // Matrix for nskip = 6 * 8 ^ 54: + { { 2495152894, 362016218, 2659927506 }, + { 1721141770, 2495152894, 2577006096 }, + { 73701594, 1721141770, 2683266250 } }, + // Matrix for nskip = 7 * 8 ^ 54: + { { 1978338540, 424481557, 341918993 }, + { 3862312182, 1978338540, 436776944 }, + { 566398653, 3862312182, 1196282660 } }, + // Matrix for nskip = 1 * 8 ^ 55: + { { 3921782078, 3376494857, 2969567377 }, + { 475345024, 3921782078, 4206379953 }, + { 1795936544, 475345024, 934679595 } }, + // Matrix for nskip = 2 * 8 ^ 55: + { { 3119292228, 741613041, 2083352304 }, + { 1047885963, 3119292228, 1581078542 }, + { 1065969969, 1047885963, 661718928 } }, + // Matrix for nskip = 3 * 8 ^ 55: + { { 3193382049, 573569291, 3880461974 }, + { 1401117517, 3193382049, 335339494 }, + { 2267936793, 1401117517, 2098160992 } }, + // Matrix for nskip = 4 * 8 ^ 55: + { { 3643472111, 2870554228, 3995474529 }, + { 3804264051, 3643472111, 1366457944 }, + { 1246805564, 3804264051, 993186530 } }, + // Matrix for nskip = 5 * 8 ^ 55: + { { 2693567720, 1775121226, 3619720132 }, + { 1859333754, 2693567720, 2377603858 }, + { 2682882800, 1859333754, 532216705 } }, + // Matrix for nskip = 6 * 8 ^ 55: + { { 2520305729, 3279882298, 2663387463 }, + { 1160802169, 2520305729, 1363372142 }, + { 92806587, 1160802169, 3842743664 } }, + // Matrix for nskip = 7 * 8 ^ 55: + { { 1402382861, 2128689614, 967911190 }, + { 1124729601, 1402382861, 1908361865 }, + { 2731098528, 1124729601, 3607037865 } }, + // Matrix for nskip = 1 * 8 ^ 56: + { { 796711791, 3878204845, 3160293932 }, + { 255632881, 796711791, 3778927111 }, + { 3472564181, 255632881, 388382377 } }, + // Matrix for nskip = 2 * 8 ^ 56: + { { 1776984101, 1742284034, 3449763933 }, + { 1349354417, 1776984101, 1264780832 }, + { 715722511, 1349354417, 1213319489 } }, + // Matrix for nskip = 3 * 8 ^ 56: + { { 3231284907, 2981539575, 3476263944 }, + { 3070932389, 3231284907, 4183678140 }, + { 4073569309, 3070932389, 1095273395 } }, + // Matrix for nskip = 4 * 8 ^ 56: + { { 4261866865, 1914382786, 201872335 }, + { 614207188, 4261866865, 1853554849 }, + { 2046042882, 614207188, 3193186353 } }, + // Matrix for nskip = 5 * 8 ^ 56: + { { 4179922982, 2821238835, 3720886954 }, + { 1712333408, 4179922982, 2683472927 }, + { 2838663503, 1712333408, 3967303913 } }, + // Matrix for nskip = 6 * 8 ^ 56: + { { 2701381139, 3664845069, 2023182114 }, + { 2420177830, 2701381139, 1924402503 }, + { 3429706463, 2420177830, 2803635446 } }, + // Matrix for nskip = 7 * 8 ^ 56: + { { 4122275824, 2032046756, 1051494202 }, + { 2221023672, 4122275824, 722305627 }, + { 547107197, 2221023672, 2228432272 } }, + // Matrix for nskip = 1 * 8 ^ 57: + { { 2210205512, 2847073169, 3324925707 }, + { 1251969297, 2210205512, 3491451503 }, + { 470400916, 1251969297, 2184392547 } }, + // Matrix for nskip = 2 * 8 ^ 57: + { { 1523590942, 2391111113, 68341529 }, + { 295466806, 1523590942, 4143310876 }, + { 3527253079, 295466806, 4059123142 } }, + // Matrix for nskip = 3 * 8 ^ 57: + { { 3667945349, 431655152, 2687669798 }, + { 1584045661, 3667945349, 2642149990 }, + { 2169193555, 1584045661, 2115882504 } }, + // Matrix for nskip = 4 * 8 ^ 57: + { { 1406902110, 3735012720, 1774518130 }, + { 1814959027, 1406902110, 1560544267 }, + { 346472965, 1814959027, 964257199 } }, + // Matrix for nskip = 5 * 8 ^ 57: + { { 2718256179, 4102604932, 4277499868 }, + { 3681834937, 2718256179, 4201441381 }, + { 1715953284, 3681834937, 1112580533 } }, + // Matrix for nskip = 6 * 8 ^ 57: + { { 992368492, 2710608111, 2674694909 }, + { 3754191262, 992368492, 1060465580 }, + { 2574962339, 3754191262, 60540513 } }, + // Matrix for nskip = 7 * 8 ^ 57: + { { 1719209658, 2756912996, 4193028814 }, + { 4256860235, 1719209658, 3552491408 }, + { 1070852068, 4256860235, 3586319939 } }, + // Matrix for nskip = 1 * 8 ^ 58: + { { 855309653, 4208503105, 1518467541 }, + { 2025248418, 855309653, 4148125749 }, + { 1349947330, 2025248418, 1168504873 } }, + // Matrix for nskip = 2 * 8 ^ 58: + { { 2375338156, 3629519168, 409696181 }, + { 252401654, 2375338156, 3992097193 }, + { 2793725401, 252401654, 1350184085 } }, + // Matrix for nskip = 3 * 8 ^ 58: + { { 2856909490, 1191427722, 3088217623 }, + { 3529719882, 2856909490, 204704202 }, + { 1918223997, 3529719882, 2282426993 } }, + // Matrix for nskip = 4 * 8 ^ 58: + { { 873141039, 3885583138, 361604799 }, + { 3554143374, 873141039, 894746180 }, + { 1919765327, 3554143374, 876210854 } }, + // Matrix for nskip = 5 * 8 ^ 58: + { { 652228317, 107568976, 2576316170 }, + { 790910548, 652228317, 1352723275 }, + { 1091561936, 790910548, 1291982092 } }, + // Matrix for nskip = 6 * 8 ^ 58: + { { 3452179482, 4206785268, 2363956864 }, + { 2619693001, 3452179482, 54522393 }, + { 4241208723, 2619693001, 2583115784 } }, + // Matrix for nskip = 7 * 8 ^ 58: + { { 547180410, 904354606, 3387638559 }, + { 2429997228, 547180410, 1350013492 }, + { 4258335371, 2429997228, 1689405508 } }, + // Matrix for nskip = 1 * 8 ^ 59: + { { 246368794, 1703793169, 2317362874 }, + { 2300930144, 246368794, 2560214589 }, + { 2016163623, 2300930144, 1504276775 } }, + // Matrix for nskip = 2 * 8 ^ 59: + { { 1574610921, 2147546631, 4103450226 }, + { 107416526, 1574610921, 1773803959 }, + { 1402542742, 107416526, 550063800 } }, + // Matrix for nskip = 3 * 8 ^ 59: + { { 2364572364, 3566983915, 468574833 }, + { 3825719596, 2364572364, 3679744745 }, + { 2445832362, 3825719596, 1752846470 } }, + // Matrix for nskip = 4 * 8 ^ 59: + { { 363388665, 592194244, 1746615522 }, + { 2637234667, 363388665, 4031408742 }, + { 2895130475, 2637234667, 296510335 } }, + // Matrix for nskip = 5 * 8 ^ 59: + { { 208003776, 91247399, 1566440482 }, + { 2144494056, 208003776, 1022614336 }, + { 2439698058, 2144494056, 4292230862 } }, + // Matrix for nskip = 6 * 8 ^ 59: + { { 2823846657, 4257316854, 3340983277 }, + { 218486499, 2823846657, 3142931989 }, + { 2351513088, 218486499, 3471595726 } }, + // Matrix for nskip = 7 * 8 ^ 59: + { { 3562083579, 3058668461, 1588504573 }, + { 2047897620, 3562083579, 1674831117 }, + { 965798968, 2047897620, 1212961148 } }, + // Matrix for nskip = 1 * 8 ^ 60: + { { 3997368560, 3047771871, 3178383826 }, + { 1160174754, 3997368560, 4027094919 }, + { 1234984211, 1160174754, 4226264344 } }, + // Matrix for nskip = 2 * 8 ^ 60: + { { 3303179301, 4243968063, 3235964171 }, + { 1776841674, 3303179301, 2867287469 }, + { 1500495759, 1776841674, 1708226553 } }, + // Matrix for nskip = 3 * 8 ^ 60: + { { 1859001036, 2962890971, 2391336228 }, + { 1694166096, 1859001036, 593465055 }, + { 1377070160, 1694166096, 2513927224 } }, + // Matrix for nskip = 4 * 8 ^ 60: + { { 1482944153, 3192311574, 354466071 }, + { 3932773012, 1482944153, 389193591 }, + { 3350181058, 3932773012, 3398059015 } }, + // Matrix for nskip = 5 * 8 ^ 60: + { { 3478906695, 565159378, 3563812138 }, + { 2637114657, 3478906695, 1117546206 }, + { 909882870, 2637114657, 2819889512 } }, + // Matrix for nskip = 6 * 8 ^ 60: + { { 3406907174, 3949116664, 536198867 }, + { 3969663510, 3406907174, 915271858 }, + { 1537382635, 3969663510, 1154112679 } }, + // Matrix for nskip = 7 * 8 ^ 60: + { { 1488624292, 2799268852, 4148140705 }, + { 2326140461, 1488624292, 2413540258 }, + { 3071215524, 2326140461, 1918378675 } }, + // Matrix for nskip = 1 * 8 ^ 61: + { { 640968550, 3226860971, 922372912 }, + { 1254989667, 640968550, 2383815228 }, + { 2027371896, 1254989667, 2925300409 } }, + // Matrix for nskip = 2 * 8 ^ 61: + { { 2313146046, 3910187183, 1377591475 }, + { 1689291784, 2313146046, 4255405993 }, + { 1650609719, 1689291784, 1897624297 } }, + // Matrix for nskip = 3 * 8 ^ 61: + { { 3547277681, 272901338, 2842437455 }, + { 1746901015, 3547277681, 4272690944 }, + { 2000451168, 1746901015, 417326012 } }, + // Matrix for nskip = 4 * 8 ^ 61: + { { 3656310954, 882924050, 2702189958 }, + { 3185020283, 3656310954, 1923190496 }, + { 2449669145, 3185020283, 4235849984 } }, + // Matrix for nskip = 5 * 8 ^ 61: + { { 3659342577, 1641516630, 2539516650 }, + { 2275633679, 3659342577, 167207049 }, + { 1798452176, 2275633679, 1651075902 } }, + // Matrix for nskip = 6 * 8 ^ 61: + { { 1932812117, 4060977130, 4129096120 }, + { 4247470915, 1932812117, 1398719693 }, + { 101546088, 4247470915, 103612315 } }, + // Matrix for nskip = 7 * 8 ^ 61: + { { 3420997084, 2682742609, 1335389027 }, + { 3883479775, 3420997084, 1501959755 }, + { 1647828648, 3883479775, 3801963100 } }, + // Matrix for nskip = 1 * 8 ^ 62: + { { 377232416, 1498446142, 4229103619 }, + { 3926377906, 377232416, 600268838 }, + { 511317726, 3926377906, 216160452 } }, + // Matrix for nskip = 2 * 8 ^ 62: + { { 1969399344, 3273966859, 4220943579 }, + { 3952111894, 1969399344, 575096961 }, + { 3815277103, 3952111894, 792177412 } }, + // Matrix for nskip = 3 * 8 ^ 62: + { { 1779275464, 2781126556, 2466688033 }, + { 1573179329, 1779275464, 2922475892 }, + { 3416534728, 1573179329, 2830179495 } }, + // Matrix for nskip = 4 * 8 ^ 62: + { { 2957238169, 1410010554, 1523740068 }, + { 3949237584, 2957238169, 74149658 }, + { 2564746147, 3949237584, 2557663578 } }, + // Matrix for nskip = 5 * 8 ^ 62: + { { 2132274169, 3311898863, 3609324462 }, + { 3719565953, 2132274169, 3678195166 }, + { 284265108, 3719565953, 4278461540 } }, + // Matrix for nskip = 6 * 8 ^ 62: + { { 2540404064, 675336157, 1264339488 }, + { 29787664, 2540404064, 3475225382 }, + { 591030331, 29787664, 1242712946 } }, + // Matrix for nskip = 7 * 8 ^ 62: + { { 3161673998, 796026877, 3360592842 }, + { 1326727008, 3161673998, 3697232048 }, + { 330692835, 1326727008, 3520194976 } }, + // Matrix for nskip = 1 * 8 ^ 63: + { { 3377318569, 1927835240, 2556102508 }, + { 3022040116, 3377318569, 2549406364 }, + { 2387074241, 3022040116, 1477293711 } }, + // Matrix for nskip = 2 * 8 ^ 63: + { { 257306870, 1748489735, 547809226 }, + { 3708493374, 257306870, 4183546362 }, + { 4435502, 3708493374, 1607696753 } }, + // Matrix for nskip = 3 * 8 ^ 63: + { { 2404623323, 4132820260, 1615062394 }, + { 1844725476, 2404623323, 570318859 }, + { 2839043606, 1844725476, 1375837008 } }, + // Matrix for nskip = 4 * 8 ^ 63: + { { 4076910933, 930542270, 3433720143 }, + { 675898567, 4076910933, 892406741 }, + { 5625977, 675898567, 2412946221 } }, + // Matrix for nskip = 5 * 8 ^ 63: + { { 3143857447, 1394551864, 4202002846 }, + { 973255696, 3143857447, 3968325674 }, + { 2327635494, 973255696, 1217794308 } }, + // Matrix for nskip = 6 * 8 ^ 63: + { { 2448094751, 2840824567, 1627957632 }, + { 1469753239, 2448094751, 4063581553 }, + { 3388871077, 1469753239, 3521935017 } }, + // Matrix for nskip = 7 * 8 ^ 63: + { { 1593620760, 1002861683, 2173731154 }, + { 3577868319, 1593620760, 39982755 }, + { 3566899985, 3577868319, 207847804 } }, + // Matrix for nskip = 1 * 8 ^ 64: + { { 2146755704, 2635194649, 1512299181 }, + { 3860948634, 2146755704, 3641948767 }, + { 3872596381, 3860948634, 1350534123 } }, + // Matrix for nskip = 2 * 8 ^ 64: + { { 2650974852, 2792146306, 1334806440 }, + { 3511147120, 2650974852, 3467471104 }, + { 2826608091, 3511147120, 3185213777 } }, + // Matrix for nskip = 3 * 8 ^ 64: + { { 4154591539, 929373784, 2614972987 }, + { 617404183, 4154591539, 1283899280 }, + { 637243382, 617404183, 1889016496 } }, + // Matrix for nskip = 4 * 8 ^ 64: + { { 1735625475, 2923145251, 885546512 }, + { 926645131, 1735625475, 2358202840 }, + { 3503695789, 926645131, 2511917556 } }, + // Matrix for nskip = 5 * 8 ^ 64: + { { 3169405477, 2071788237, 2197719325 }, + { 3454276765, 3169405477, 354513440 }, + { 3433509316, 3454276765, 3884018107 } }, + // Matrix for nskip = 6 * 8 ^ 64: + { { 154139786, 961249414, 3740576106 }, + { 1113118249, 154139786, 3880685356 }, + { 177260972, 1113118249, 1811433812 } }, + // Matrix for nskip = 7 * 8 ^ 64: + { { 2636917497, 3922853891, 3167851814 }, + { 911696899, 2636917497, 1449426394 }, + { 2845905825, 911696899, 1062710260 } } } +}; + +} // namespace mrg32k3a_impl +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_MRG32K3A_SKIP_AHEAD_MATRIX_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/philox4x32x10_impl.hpp b/include/oneapi/mkl/rng/device/detail/philox4x32x10_impl.hpp new file mode 100644 index 000000000..f061bb754 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/philox4x32x10_impl.hpp @@ -0,0 +1,552 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_PHILOX4X32X10_IMPL_HPP_ +#define _MKL_RNG_DEVICE_PHILOX4X32X10_IMPL_HPP_ + +#include // std::pair + +namespace oneapi::mkl::rng::device { + +template +class philox4x32x10; + +namespace detail { + +template +struct engine_state> { + std::uint32_t key[2]; + std::uint32_t counter[4]; + std::uint32_t part; + std::uint32_t result[4]; +}; + +namespace philox4x32x10_impl { + +static inline void add128(std::uint32_t* a, std::uint64_t b) { + std::uint64_t tmp = ((static_cast(a[1]) << 32) | a[0]); + + tmp += b; + + a[0] = static_cast(tmp); + a[1] = static_cast(tmp >> 32); + + if (tmp < b) { + tmp = ((static_cast(a[3]) << 32) | a[2]) + 1; + + a[2] = static_cast(tmp); + a[3] = static_cast(tmp >> 32); + } + return; +} + +static inline void add128_1(std::uint32_t* a) { + if (++a[0]) { + return; + } + if (++a[1]) { + return; + } + if (++a[2]) { + return; + } + ++a[3]; +} + +static inline std::pair mul_hilo_32(std::uint32_t a, + std::uint32_t b) { + std::uint64_t res_64 = static_cast(a) * static_cast(b); + return std::make_pair(static_cast(res_64), + static_cast(res_64 >> 32)); +} + +static inline void round(std::uint32_t* cnt, std::uint32_t* k) { + auto [L0, H0] = mul_hilo_32(0xD2511F53, cnt[0]); + auto [L1, H1] = mul_hilo_32(0xCD9E8D57, cnt[2]); + + cnt[0] = H1 ^ cnt[1] ^ k[0]; + cnt[1] = L1; + cnt[2] = H0 ^ cnt[3] ^ k[1]; + cnt[3] = L0; +} + +static inline void round_10(std::uint32_t* cnt, std::uint32_t* k) { + round(cnt, k); // 1 + // increasing keys with philox4x32x10 constants + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 2 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 3 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 4 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 5 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 6 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 7 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 8 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 9 + k[0] += 0x9E3779B9; + k[1] += 0xBB67AE85; + round(cnt, k); // 10 +} + +template +static inline void skip_ahead(engine_state>& state, + std::uint64_t num_to_skip) { + std::uint64_t num_to_skip_tmp = num_to_skip; + std::uint64_t c_inc; + std::uint32_t counter[4]; + std::uint32_t key[2]; + std::uint64_t tail; + if (num_to_skip_tmp <= state.part) { + state.part -= num_to_skip_tmp; + } + else { + tail = num_to_skip % 4; + if ((tail == 0) && (state.part == 0)) { + add128(state.counter, num_to_skip / 4); + } + else { + num_to_skip_tmp = num_to_skip_tmp - state.part; + state.part = 0; + c_inc = (num_to_skip_tmp - 1) / 4; + state.part = (4 - num_to_skip_tmp % 4) % 4; + add128(state.counter, c_inc); + counter[0] = state.counter[0]; + counter[1] = state.counter[1]; + counter[2] = state.counter[2]; + counter[3] = state.counter[3]; + key[0] = state.key[0]; + key[1] = state.key[1]; + round_10(counter, key); + state.result[0] = counter[0]; + state.result[1] = counter[1]; + state.result[2] = counter[2]; + state.result[3] = counter[3]; + add128_1(state.counter); + } + } +} + +template +static inline void skip_ahead(engine_state>& state, + std::uint64_t n, const std::uint64_t* num_to_skip_ptr) { + constexpr std::uint64_t uint_max = 0xFFFFFFFFFFFFFFFF; + std::uint64_t post_buffer, pre_buffer; + std::int32_t num_elements = 0; + std::int32_t remained_counter; + std::uint64_t tmp_skip_array[3] = { 0, 0, 0 }; + + for (std::uint64_t i = 0; (i < 3) && (i < n); i++) { + tmp_skip_array[i] = num_to_skip_ptr[i]; + if (tmp_skip_array[i]) { + num_elements = i + 1; + } + } + + if (num_elements == 0) { + return; + } + if ((num_elements == 1) && (tmp_skip_array[0] <= state.part)) { + state.part -= static_cast(tmp_skip_array[0]); + return; + } + std::uint32_t counter[4]; + std::uint32_t key[2]; + + if ((tmp_skip_array[0] - state.part) <= tmp_skip_array[0]) { + tmp_skip_array[0] = tmp_skip_array[0] - state.part; + } + else if ((num_elements == 2) || (tmp_skip_array[1] - 1 < tmp_skip_array[1])) { + tmp_skip_array[1] = tmp_skip_array[1] - 1; + tmp_skip_array[0] = uint_max - state.part + tmp_skip_array[0]; + } + else { + tmp_skip_array[2] = tmp_skip_array[2] - 1; + tmp_skip_array[1] = uint_max - 1; + tmp_skip_array[0] = uint_max - state.part + tmp_skip_array[0]; + } + + state.part = 0; + + post_buffer = 0; + + remained_counter = static_cast(tmp_skip_array[0] % 4); + + for (int i = num_elements - 1; i >= 0; i--) { + pre_buffer = (tmp_skip_array[i] << 62); + tmp_skip_array[i] >>= 2; + tmp_skip_array[i] |= post_buffer; + post_buffer = pre_buffer; + } + + state.part = 4 - remained_counter; + + std::uint64_t counter64[] = { state.counter[1], state.counter[3] }; + counter64[0] = ((counter64[0] << 32ull) | state.counter[0]); + counter64[1] = ((counter64[1] << 32ull) | state.counter[2]); + + counter64[0] += tmp_skip_array[0]; + + if (counter64[0] < tmp_skip_array[0]) { + counter64[1]++; + } + + counter64[1] += tmp_skip_array[1]; + + counter[0] = static_cast(counter64[0]); + counter[1] = static_cast(counter64[0] >> 32); + counter[2] = static_cast(counter64[1]); + counter[3] = static_cast(counter64[1] >> 32); + + key[0] = state.key[0]; + key[1] = state.key[1]; + + round_10(counter, key); + + state.result[0] = counter[0]; + state.result[1] = counter[1]; + state.result[2] = counter[2]; + state.result[3] = counter[3]; + + counter64[0]++; + + if (counter64[0] < 1) { + counter64[1]++; + } + + state.counter[0] = static_cast(counter64[0]); + state.counter[1] = static_cast(counter64[0] >> 32); + state.counter[2] = static_cast(counter64[1]); + state.counter[3] = static_cast(counter64[1] >> 32); +} + +template +static inline void init(engine_state>& state, + std::uint64_t n, const std::uint64_t* seed_ptr, std::uint64_t offset) { + state.key[0] = static_cast(seed_ptr[0]); + state.key[1] = static_cast(seed_ptr[0] >> 32); + + state.counter[0] = (n >= 2 ? static_cast(seed_ptr[1]) : 0); + state.counter[1] = (n >= 2 ? static_cast(seed_ptr[1] >> 32) : 0); + + state.counter[2] = (n >= 3 ? static_cast(seed_ptr[2]) : 0); + state.counter[3] = (n >= 3 ? static_cast(seed_ptr[2] >> 32) : 0); + + state.part = 0; + state.result[0] = 0; + state.result[1] = 0; + state.result[2] = 0; + state.result[3] = 0; + skip_ahead(state, offset); +} + +template +static inline void init(engine_state>& state, + std::uint64_t n, const std::uint64_t* seed_ptr, std::uint64_t n_offset, + const std::uint64_t* offset_ptr) { + state.key[0] = static_cast(seed_ptr[0]); + state.key[1] = static_cast(seed_ptr[0] >> 32); + + state.counter[0] = (n >= 2 ? static_cast(seed_ptr[1]) : 0); + state.counter[1] = (n >= 2 ? static_cast(seed_ptr[1] >> 32) : 0); + + state.counter[2] = (n >= 3 ? static_cast(seed_ptr[2]) : 0); + state.counter[3] = (n >= 3 ? static_cast(seed_ptr[2] >> 32) : 0); + + state.part = 0; + state.result[0] = 0; + state.result[1] = 0; + state.result[2] = 0; + state.result[3] = 0; + skip_ahead(state, n_offset, offset_ptr); +} + +// for VecSize > 4 +template +__attribute__((always_inline)) static inline sycl::vec generate_full( + engine_state>& state) { + const std::int32_t num_elements = VecSize; + sycl::vec res; + + std::uint32_t counter[4]; + + int i = 0; + int part = (int)state.part; + while (part && (i < num_elements)) { + res[i++] = state.result[3 - (--part)]; + } + if (i == num_elements) { + skip_ahead(state, num_elements); + return res; + } + + counter[0] = state.counter[0]; + counter[1] = state.counter[1]; + counter[2] = state.counter[2]; + counter[3] = state.counter[3]; + + std::uint32_t cntTmp[4]; + std::uint32_t keyTmp[2]; + for (; i < num_elements; i += 4) { + cntTmp[0] = counter[0]; + cntTmp[1] = counter[1]; + cntTmp[2] = counter[2]; + cntTmp[3] = counter[3]; + + keyTmp[0] = state.key[0]; + keyTmp[1] = state.key[1]; + + round_10(cntTmp, keyTmp); + + if (i + 4 <= num_elements) { + for (int j = 0; j < 4; j++) { + res[i + j] = cntTmp[j]; + } + add128_1(counter); + } + else { + // here if last iteration + for (int j = 0; i < num_elements; i++, j++) { + res[i] = cntTmp[j]; + } + } + } + skip_ahead(state, num_elements); + return res; +} + +// for VecSize <= 4 +template +__attribute__((always_inline)) static inline sycl::vec generate_small( + engine_state>& state) { + const std::int32_t num_elements = VecSize; + sycl::vec res; + + std::uint32_t counter[4]; + std::uint32_t key[2]; + + int i = 0; + int part = (int)state.part; + while (part && (i < num_elements)) { + res[i++] = state.result[3 - (--part)]; + } + if (i == num_elements) { + skip_ahead(state, num_elements); + return res; + } + + counter[0] = state.counter[0]; + counter[1] = state.counter[1]; + counter[2] = state.counter[2]; + counter[3] = state.counter[3]; + key[0] = state.key[0]; + key[1] = state.key[1]; + + round_10(counter, key); + + for (int j = 0; i < num_elements; i++, j++) { + res[i] = counter[j]; + } + + skip_ahead(state, num_elements); + return res; +} + +template +__attribute__((always_inline)) static inline std::uint32_t generate_single( + engine_state>& state) { + std::uint32_t res; + + std::uint32_t counter[4]; + std::uint32_t key[2]; + + std::int32_t part = static_cast(state.part); + if (part != 0) { + res = state.result[3 - (--part)]; + skip_ahead(state, 1); + return res; + } + counter[0] = state.counter[0]; + counter[1] = state.counter[1]; + counter[2] = state.counter[2]; + counter[3] = state.counter[3]; + key[0] = state.key[0]; + key[1] = state.key[1]; + + round_10(counter, key); + + res = counter[0]; + + skip_ahead(state, 1); + return res; +} + +} // namespace philox4x32x10_impl + +template +class engine_base> { +protected: + engine_base(std::uint64_t seed, std::uint64_t offset = 0) { + philox4x32x10_impl::init(this->state_, 1, &seed, offset); + } + + engine_base(std::uint64_t n, const std::uint64_t* seed, std::uint64_t offset = 0) { + philox4x32x10_impl::init(this->state_, n, seed, offset); + } + + engine_base(std::uint64_t seed, std::uint64_t n_offset, const std::uint64_t* offset_ptr) { + philox4x32x10_impl::init(this->state_, 1, &seed, n_offset, offset_ptr); + } + + engine_base(std::uint64_t n, const std::uint64_t* seed, std::uint64_t n_offset, + const std::uint64_t* offset_ptr) { + philox4x32x10_impl::init(this->state_, n, seed, n_offset, offset_ptr); + } + + template + __attribute__((always_inline)) inline auto generate(RealType a, RealType b) -> + typename std::conditional>::type { + sycl::vec res; + sycl::vec res_uint; + RealType a1; + RealType c1; + + c1 = (b - a) / (static_cast((std::numeric_limits::max)()) + 1); + a1 = (b + a) / static_cast(2.0); + + if constexpr (VecSize > 4) { + res_uint = philox4x32x10_impl::generate_full(this->state_); + } + else { + res_uint = philox4x32x10_impl::generate_small(this->state_); + } + for (int i = 0; i < VecSize; i++) { + res[i] = static_cast(static_cast(res_uint[i])) * c1 + a1; + } + return res; + } + + __attribute__((always_inline)) inline auto generate() -> + typename std::conditional>::type { + if constexpr (VecSize > 4) { + return philox4x32x10_impl::generate_full(this->state_); + } + return philox4x32x10_impl::generate_small(this->state_); + } + + template + __attribute__((always_inline)) inline auto generate_uniform_bits() -> + typename std::conditional>::type { + if constexpr (std::is_same::value) { + return generate(); + } + else { + auto uni_res1 = generate(); + auto uni_res2 = generate(); + + if constexpr (VecSize == 1) { + return (static_cast(uni_res2) << 32) + uni_res1; + } + else { + sycl::vec vec_out; + + if constexpr (VecSize != 3) { + for (int i = 0; i < VecSize / 2; i++) { + vec_out[i] = (static_cast(uni_res1[2 * i + 1]) << 32) + + uni_res1[2 * i]; + vec_out[i + VecSize / 2] = + (static_cast(uni_res2[2 * i + 1]) << 32) + + uni_res2[2 * i]; + } + } + else { + vec_out[0] = (static_cast(uni_res1[1]) << 32) + uni_res1[0]; + vec_out[1] = (static_cast(uni_res2[0]) << 32) + uni_res1[2]; + vec_out[2] = (static_cast(uni_res2[2]) << 32) + uni_res2[1]; + } + + return vec_out; + } + } + } + + template + RealType generate_single(RealType a, RealType b) { + RealType res; + std::uint32_t res_uint; + RealType a1; + RealType c1; + + c1 = (b - a) / (static_cast((std::numeric_limits::max)()) + 1); + a1 = (b + a) / static_cast(2.0); + + res_uint = philox4x32x10_impl::generate_single(this->state_); + + res = static_cast(static_cast(res_uint)) * c1 + a1; + + return res; + } + + __attribute__((always_inline)) inline std::uint32_t generate_single() { + return philox4x32x10_impl::generate_single(this->state_); + } + + template + __attribute__((always_inline)) inline auto generate_single_uniform_bits() { + if constexpr (std::is_same::value) { + return philox4x32x10_impl::generate_single(this->state_); + } + else { + auto uni_res1 = philox4x32x10_impl::generate_single(this->state_); + auto uni_res2 = philox4x32x10_impl::generate_single(this->state_); + + return (static_cast(uni_res2) << 32) + uni_res1; + } + } + + void skip_ahead(std::uint64_t num_to_skip) { + detail::philox4x32x10_impl::skip_ahead(this->state_, num_to_skip); + } + + void skip_ahead(std::initializer_list num_to_skip) { + detail::philox4x32x10_impl::skip_ahead(this->state_, num_to_skip.size(), + num_to_skip.begin()); + } + + engine_state> state_; +}; + +} // namespace detail +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_PHILOX4X32X10_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/poisson_impl.hpp b/include/oneapi/mkl/rng/device/detail/poisson_impl.hpp new file mode 100644 index 000000000..9fa9b26ec --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/poisson_impl.hpp @@ -0,0 +1,355 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_POISSON_IMPL_HPP_ +#define _MKL_RNG_DEVICE_POISSON_IMPL_HPP_ + +#include + +namespace oneapi::mkl::rng::device::detail { + +// Implementation of Poisson distribution uses 3 methods depending on lambda parameter: +// - table-lookup method [1] for small lambdas (lambda < 60) +// - Devroye's method [2] for medium lambdas (60 <= lambda < 1000) +// - Gaussian approximation [1] for huge lambdas (lambda >= 1000) +// +// References: +// [1] Michael B. Giles // Algorithm 955: approximation of the inverse Poisson cumulative +// distribution function +// [2] Devroye, L. Non-Uniform Random Variates Generation. Springer-Verlag, +// New York, 1986, Ch. X, Sects. 3.3 & 3.4 + Errata + +#define RNG_POISSON_LAMBDA_HUGE_BOUND 1000.0 +#define RNG_POISSON_LAMBDA_LOW_BOUND 60.0 +#define RNG_POISSON_N_PRECOMPUTED_CDF 32 + +struct poisson_parameters { + void set_lambda(double lambda) { + if (lambda >= RNG_POISSON_LAMBDA_HUGE_BOUND) { + sqrt_lambda_ = sycl::sqrt(lambda); + } + else if (lambda >= RNG_POISSON_LAMBDA_LOW_BOUND) { + floored_lambda_ = sycl::floor(lambda); + log_lambda_ = sycl::log(lambda); + lgamma_floored_lambda_ = sycl::lgamma(floored_lambda_ + 1.0); + sqrt_floored_lambda_ = sycl::sqrt(floored_lambda_); + dx_ = sycl::sqrt(2.0 * floored_lambda_ * sycl::log(32.0 * floored_lambda_ / pi_4_)); + delta_ = sycl::round((sycl::max)(6.0, (sycl::min)(floored_lambda_, dx_))); + dpdfl_ = delta_ + 2.0 * floored_lambda_; + sqrt_half_dpdfl_ = sycl::sqrt(dpdfl_ / 2.0); + inv_dpdfl_ = 1.0 / dpdfl_; + c2_add_coeff_ = sycl::sqrt(pi_4_ * dpdfl_) * sycl::exp(inv_dpdfl_); + c_add_coeff_ = + 2.0 * dpdfl_ * sycl::exp(-delta_ * inv_dpdfl_ * (1.0 + delta_ / 2.0)) / delta_; + c1_ = sqrt_floored_lambda_ * spi_2_; + c2_ = c2_add_coeff_ + c1_; + c3_ = c2_ + 1.0; + c4_ = c2_ + 2.0; + c5_ = c4_ + exp_one_by_78; + c_ = c5_ + c_add_coeff_; + } + else { + prob[0] = sycl::exp(-lambda); + double tmp = prob[0]; + for (int i = 1; i < RNG_POISSON_N_PRECOMPUTED_CDF; ++i) { + tmp *= lambda / (double)i; + prob[i] = prob[i - 1] + tmp; + } + } + } + + poisson_parameters& operator=(const poisson_parameters& other) { + if (this == &other) { + return *this; + } + for (int i = 0; i < RNG_POISSON_N_PRECOMPUTED_CDF; i++) { + prob[i] = other.prob[i]; + } + floored_lambda_ = other.floored_lambda_; + log_lambda_ = other.log_lambda_; + lgamma_floored_lambda_ = other.lgamma_floored_lambda_; + sqrt_lambda_ = other.sqrt_lambda_; + sqrt_floored_lambda_ = other.sqrt_floored_lambda_; + dx_ = other.dx_; + delta_ = other.delta_; + dpdfl_ = other.dpdfl_; + sqrt_half_dpdfl_ = other.sqrt_half_dpdfl_; + inv_dpdfl_ = other.inv_dpdfl_; + c2_add_coeff_ = other.c2_add_coeff_; + c_add_coeff_ = other.c_add_coeff_; + c1_ = other.c1_; + c2_ = other.c2_; + c3_ = other.c3_; + c4_ = other.c4_; + c5_ = other.c5_; + c_ = other.c_; + return *this; + } + double prob[RNG_POISSON_N_PRECOMPUTED_CDF]; + double floored_lambda_ = 0.0; + double log_lambda_ = 0.0; + double lgamma_floored_lambda_ = 0.0; + double sqrt_lambda_ = 0.0; + double sqrt_floored_lambda_ = 0.0; + double dx_ = 0.0; + double delta_ = 0.0; + double dpdfl_ = 0.0; + double sqrt_half_dpdfl_ = 0.0; + double inv_dpdfl_ = 0.0; + double c2_add_coeff_ = 0.0; + double c_add_coeff_ = 0.0; + double c1_ = 0.0; + double c2_ = 0.0; + double c3_ = 0.0; + double c4_ = 0.0; + double c5_ = 0.0; + double c_ = 0.0; + const double exp_one_by_78 = 1.0129030479320018583185514777512983L; + const double pi_4_ = 0.7853981633974483096156608458198757L; + const double spi_2_ = 1.2533141373155002512078826424055226L; +}; + +template +class distribution_base> { +public: + struct param_type { + param_type(double lambda) : lambda_(lambda) {} + double lambda_; + }; + + distribution_base(double lambda) : lambda_(lambda) { +#ifndef __SYCL_DEVICE_ONLY__ + if (lambda_ <= 0.0) { + throw oneapi::mkl::invalid_argument("rng", "poisson", "lambda <= 0"); + } +#endif + params_.set_lambda(lambda_); + } + + double lambda() const { + return lambda_; + } + + param_type param() const { + return param_type(lambda_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.lambda_ <= 0.0) { + throw oneapi::mkl::invalid_argument("rng", "poisson", "lambda <= 0"); + } +#endif + lambda_ = pt.lambda_; + params_.set_lambda(lambda_); + } + +protected: + IntType get_one_num_small_lambdas(double uniform_var) { + IntType res = 0; + if (uniform_var < params_.prob[0]) { + return res; + } + else { + for (res = 1; res < RNG_POISSON_N_PRECOMPUTED_CDF; ++res) { + if (uniform_var < params_.prob[res]) { + return res; + } + } + // in case uniform_var is still bigger than CDF[31] compute additional CDF coefficients + double prob_less_than_k = params_.prob[--res]; + double prob_that_k = prob_less_than_k - params_.prob[res - 1]; + do { + prob_that_k *= lambda_ / (double)(res++ + 1); + prob_less_than_k += prob_that_k; + } while (uniform_var >= prob_less_than_k); + + return res; + } + } + template + IntType get_one_num_med_lambdas(EngineType& engine) { + const double rounding_coeff = (1.0 - std::numeric_limits::epsilon()) / 2.0; + const double max_inttype_val = (std::numeric_limits::max)() + rounding_coeff; + double res_; + bool rejection_flag = true; + do { + const double uniform_var = params_.c_ * engine.generate_single(0.0, 1.0); + const double exponential_var = exponential_.generate_single(engine); + double w = 0.0; + if (uniform_var <= params_.c1_) { + const double gaussian_var = gaussian_.generate_single(engine); + const double y = -sycl::fabs(gaussian_var) * params_.sqrt_floored_lambda_ - 1.0; + res_ = sycl::floor(y); + w = -gaussian_var * gaussian_var / 2.0; + if (res_ < -params_.floored_lambda_) + continue; + } + else if (uniform_var <= params_.c2_) { + const double gaussian_var = gaussian_.generate_single(engine); + const double y = 1.0 + sycl::fabs(gaussian_var) * params_.sqrt_half_dpdfl_; + res_ = sycl::ceil(y); + w = y * (2.0 - y) * params_.inv_dpdfl_; + if (res_ > params_.delta_) + continue; + } + else if (uniform_var <= params_.c3_) + res_ = -1.0; + else if (uniform_var <= params_.c4_) + res_ = 0.0; + else if (uniform_var <= params_.c5_) + res_ = 1.0; + else { + const double exponential_var_1 = exponential_.generate_single(engine); + const double y = + params_.delta_ + exponential_var_1 * 2.0 * params_.dpdfl_ / params_.delta_; + res_ = sycl::ceil(y); + w = -params_.delta_ * params_.inv_dpdfl_ * (1.0 + y / 2.0); + } + + rejection_flag = ((w - exponential_var - res_ * params_.log_lambda_) > + (params_.lgamma_floored_lambda_ - + sycl::lgamma(res_ + params_.floored_lambda_ + 1.0))); + + rejection_flag |= (res_ + params_.floored_lambda_) >= max_inttype_val; + + } while (rejection_flag); + + return ((IntType)(res_ + params_.floored_lambda_ + rounding_coeff)); + } + + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + using OutType = typename std::conditional>::type; + OutType res; + if constexpr (EngineType::vec_size == 1) { + res = 0; + if (lambda_ < RNG_POISSON_LAMBDA_LOW_BOUND) { + double uniform_var = engine.generate(0.0, 1.0); + return get_one_num_small_lambdas(uniform_var); + } + else if (lambda_ < RNG_POISSON_LAMBDA_HUGE_BOUND) { + const double rounding_coeff = (1.0 - std::numeric_limits::epsilon()) / 2.0; + const double max_inttype_val = + (std::numeric_limits::max)() + rounding_coeff; + double res_; + bool rejection_flag = true; + do { + const double uniform_var = params_.c_ * engine.generate(0.0, 1.0); + const double exponential_var = exponential_.generate(engine); + double w = 0.0; + if (uniform_var <= params_.c1_) { + const double gaussian_var = gaussian_.generate(engine); + const double y = + -sycl::fabs(gaussian_var) * params_.sqrt_floored_lambda_ - 1.0; + res_ = sycl::floor(y); + w = -gaussian_var * gaussian_var / 2.0; + if (res_ < -params_.floored_lambda_) + continue; + } + else if (uniform_var <= params_.c2_) { + const double gaussian_var = gaussian_.generate(engine); + const double y = 1.0 + sycl::fabs(gaussian_var) * params_.sqrt_half_dpdfl_; + res_ = sycl::ceil(y); + w = y * (2.0 - y) * params_.inv_dpdfl_; + if (res_ > params_.delta_) + continue; + } + else if (uniform_var <= params_.c3_) + res_ = -1.0; + else if (uniform_var <= params_.c4_) + res_ = 0.0; + else if (uniform_var <= params_.c5_) + res_ = 1.0; + else { + const double exponential_var_1 = exponential_.generate(engine); + const double y = params_.delta_ + + exponential_var_1 * 2.0 * params_.dpdfl_ / params_.delta_; + res_ = sycl::ceil(y); + w = -params_.delta_ * params_.inv_dpdfl_ * (1.0 + y / 2.0); + } + + rejection_flag = ((w - exponential_var - res_ * params_.log_lambda_) > + (params_.lgamma_floored_lambda_ - + sycl::lgamma(res_ + params_.floored_lambda_ + 1.0))); + + rejection_flag |= (res_ + params_.floored_lambda_) >= max_inttype_val; + + } while (rejection_flag); + + return ((IntType)(res_ + params_.floored_lambda_ + rounding_coeff)); + } + else { + res = static_cast(lambda_ + + params_.sqrt_lambda_ * gaussian_.generate(engine)); + } + } + else { + if (lambda_ < RNG_POISSON_LAMBDA_LOW_BOUND) { + auto uniform_var = engine.generate(0.0, 1.0); + for (int i = 0; i < EngineType::vec_size; ++i) { + res[i] = get_one_num_small_lambdas(uniform_var[i]); + } + return res; + } + else if (lambda_ < RNG_POISSON_LAMBDA_HUGE_BOUND) { + for (int i = 0; i < EngineType::vec_size; ++i) { + res[i] = get_one_num_med_lambdas(engine); + } + return res; + } + else { + sycl::vec res_fp = + lambda_ + params_.sqrt_lambda_ * gaussian_.generate(engine); + res_fp = sycl::floor(res_fp); + res = res_fp.template convert(); + } + } + return res; + } + + template + IntType generate_single(EngineType& engine) { + IntType res = 0; + if (lambda_ < RNG_POISSON_LAMBDA_LOW_BOUND) { + double uniform_var = engine.generate_single(0.0, 1.0); + return get_one_num_small_lambdas(uniform_var); + } + else if (lambda_ < RNG_POISSON_LAMBDA_HUGE_BOUND) { + return get_one_num_med_lambdas(engine); + } + else { + res = static_cast(lambda_ + + params_.sqrt_lambda_ * gaussian_.generate_single(engine)); + } + return res; + } + + distribution_base> gaussian_ = { 0.0, 1.0 }; + distribution_base> exponential_ = { 0.0, 1.0 }; + poisson_parameters params_; + double lambda_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_POISSON_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/uniform_bits_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_bits_impl.hpp new file mode 100644 index 000000000..cd3cd2eed --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/uniform_bits_impl.hpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_UNIFORM_BITS_IMPL_HPP_ +#define _MKL_RNG_DEVICE_UNIFORM_BITS_IMPL_HPP_ + +#include "engine_base.hpp" + +namespace oneapi::mkl::rng::device::detail { + +template +class distribution_base> { +protected: + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + static_assert(std::is_same>::value || + std::is_same>::value, + "oneMKL: uniform_bits works only with philox4x32x10/mcg59 engines"); + return engine.template generate_uniform_bits(); + } + + template + UIntType generate_single(EngineType& engine) { + static_assert(std::is_same>::value || + std::is_same>::value, + "oneMKL: uniform_bits works only with philox4x32x10/mcg59 engines"); + return engine.template generate_single_uniform_bits(); + } +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_UNIFORM_BITS_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp new file mode 100644 index 000000000..bdd7f79d7 --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/uniform_impl.hpp @@ -0,0 +1,131 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_UNIFORM_IMPL_HPP_ +#define _MKL_RNG_DEVICE_UNIFORM_IMPL_HPP_ + +namespace oneapi::mkl::rng::device::detail { + +template +class distribution_base> { +public: + struct param_type { + param_type(Type a, Type b) : a_(a), b_(b) {} + Type a_; + Type b_; + }; + + distribution_base(Type a, Type b) : a_(a), b_(b) { +#ifndef __SYCL_DEVICE_ONLY__ + if (a >= b) { + throw oneapi::mkl::invalid_argument("rng", "uniform", "a >= b"); + } +#endif + } + + Type a() const { + return a_; + } + + Type b() const { + return b_; + } + + param_type param() const { + return param_type(a_, b_); + } + + void param(const param_type& pt) { +#ifndef __SYCL_DEVICE_ONLY__ + if (pt.a_ >= pt.b_) { + throw oneapi::mkl::invalid_argument("rng", "uniform", "a >= b"); + } +#endif + a_ = pt.a_; + b_ = pt.b_; + } + +protected: + template + auto generate(EngineType& engine) -> + typename std::conditional>::type { + using OutType = typename std::conditional>::type; + using FpType = + typename std::conditional::value, double, + float>::type; + OutType res; + if constexpr (std::is_integral::value) { + if constexpr (EngineType::vec_size == 1) { + FpType res_fp = engine.generate(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = static_cast(res_fp); + return res; + } + else { + sycl::vec res_fp; + res_fp = engine.generate(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = res_fp.template convert(); + return res; + } + } + else { + res = engine.generate(a_, b_); + if constexpr (std::is_same::value) { + res = sycl::fmax(res, OutType{ a_ }); + res = sycl::fmin(res, OutType{ b_ }); + } + } + + return res; + } + + template + Type generate_single(EngineType& engine) { + using FpType = + typename std::conditional::value, double, + float>::type; + Type res; + if constexpr (std::is_integral::value) { + FpType res_fp = + engine.generate_single(static_cast(a_), static_cast(b_)); + res_fp = sycl::floor(res_fp); + res = static_cast(res_fp); + return res; + } + else { + res = engine.generate_single(a_, b_); + if constexpr (std::is_same::value) { + res = sycl::fmax(res, a_); + res = sycl::fmin(res, b_); + } + } + + return res; + } + + Type a_; + Type b_; +}; + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_UNIFORM_IMPL_HPP_ diff --git a/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp b/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp new file mode 100644 index 000000000..ec070c92c --- /dev/null +++ b/include/oneapi/mkl/rng/device/detail/vm_wrappers.hpp @@ -0,0 +1,61 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_VM_WRAPPERS_HPP_ +#define _MKL_RNG_DEVICE_VM_WRAPPERS_HPP_ + +#include + +namespace oneapi::mkl::rng::device::detail { + +template +static inline DataType sqrt_wrapper(DataType a) { + return sycl::sqrt(a); +} + +template +static inline DataType sinpi_wrapper(DataType a) { + return sycl::sinpi(a); +} + +template +static inline DataType cospi_wrapper(DataType a) { + return sycl::cospi(a); +} + +template +static inline DataType sincospi_wrapper(DataType a, DataType& b) { + b = sycl::cospi(a); + return sycl::sinpi(a); +} + +template +static inline DataType ln_wrapper(DataType a) { + if (a == DataType(0)) { + if constexpr (std::is_same_v) + return -0x1.74385446D71C3P+9; // ln(0.494065e-323) = -744.440072 + else + return -0x1.9D1DA0P+6f; // ln(0.14012984e-44) = -103.278929 + } + return sycl::log(a); +} + +} // namespace oneapi::mkl::rng::device::detail + +#endif // _MKL_RNG_DEVICE_VM_WRAPPERS_HPP_ diff --git a/include/oneapi/mkl/rng/device/distributions.hpp b/include/oneapi/mkl/rng/device/distributions.hpp new file mode 100644 index 000000000..21739f7f2 --- /dev/null +++ b/include/oneapi/mkl/rng/device/distributions.hpp @@ -0,0 +1,480 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_DISTRIBUTIONS_HPP_ +#define _MKL_RNG_DEVICE_DISTRIBUTIONS_HPP_ + +#include + +#include "oneapi/mkl/rng/device/detail/distribution_base.hpp" +#include "oneapi/mkl/rng/device/functions.hpp" + +namespace oneapi::mkl::rng::device { + +// CONTINUOUS AND DISCRETE RANDOM NUMBER DISTRIBUTIONS + +// Class template oneapi::mkl::rng::device::uniform +// +// Represents continuous and discrete uniform random number distribution +// +// Supported types: +// float +// double +// std::int32_t +// std::uint32_t +// +// Supported methods: +// oneapi::mkl::rng::device::uniform_method::standard +// oneapi::mkl::rng::device::uniform_method::accurate +// +// Input arguments: +// a - left bound. 0.0 by default +// b - right bound. 1.0 by default (for std::(u)int32_t std::numeric_limits::max() +// is used for accurate method and 2^23 is used for standard method) +// +// Note: using (un)signed integer uniform distribution with uniform_method::standard method may +// cause incorrect statistics of the produced random numbers (due to rounding error) if +// (abs(b - a) > 2^23) || (abs(b) > 2^23) || (abs(a) > 2^23) +// Please use uniform_method::accurate method instead +// +template +class uniform : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/uniform: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value || + std::is_same::value || + std::is_same::value, + "oneMKL: rng/uniform: type is not supported"); + + using method_type = Method; + using result_type = Type; + using param_type = typename detail::distribution_base>::param_type; + + uniform() + : detail::distribution_base>( + static_cast(0.0), + std::is_integral::value + ? (std::is_same::value + ? (1 << 23) + : (std::numeric_limits::max)()) + : static_cast(1.0)) {} + + explicit uniform(Type a, Type b) : detail::distribution_base>(a, b) {} + explicit uniform(const param_type& pt) + : detail::distribution_base>(pt.a_, pt.b_) {} + + Type a() const { + return detail::distribution_base>::a(); + } + + Type b() const { + return detail::distribution_base>::b(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + +private: + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::gaussian +// +// Represents continuous normal random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::gaussian_method::box_muller2 +// oneapi::mkl::rng::device::gaussian_method::icdf +// +// Input arguments: +// mean - mean. 0 by default +// stddev - standard deviation. 1.0 by default +// +template +class gaussian : detail::distribution_base> { +public: + static_assert(std::is_same::value +#if MKL_RNG_USE_BINARY_CODE + || std::is_same::value +#endif + , + "oneMKL: rng/gaussian: method is incorrect"); +#if !MKL_RNG_USE_BINARY_CODE + static_assert(!std::is_same::value, "icdf method not supported"); +#endif + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/gaussian: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = typename detail::distribution_base>::param_type; + + gaussian() + : detail::distribution_base>(static_cast(0.0), + static_cast(1.0)) {} + + explicit gaussian(RealType mean, RealType stddev) + : detail::distribution_base>(mean, stddev) {} + explicit gaussian(const param_type& pt) + : detail::distribution_base>(pt.mean_, pt.stddev_) {} + + RealType mean() const { + return detail::distribution_base>::mean(); + } + + RealType stddev() const { + return detail::distribution_base>::stddev(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::lognormal +// +// Represents continuous lognormal random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::lognormal_method::box_muller2 +// +// Input arguments: +// m - mean of the subject normal distribution. 0.0 by default +// s - standard deviation of the subject normal distribution. 1.0 by default +// displ - displacement. 0.0 by default +// scale - scalefactor. 1.0 by default +// +template +class lognormal : detail::distribution_base> { +public: + static_assert(std::is_same::value, + "oneMKL: rng/lognormal: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/lognormal: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = typename detail::distribution_base>::param_type; + + lognormal() + : detail::distribution_base>( + static_cast(0.0), static_cast(1.0), + static_cast(0.0), static_cast(1.0)) {} + + explicit lognormal(RealType m, RealType s, RealType displ = static_cast(0.0), + RealType scale = static_cast(1.0)) + : detail::distribution_base>(m, s, displ, scale) {} + explicit lognormal(const param_type& pt) + : detail::distribution_base>(pt.m_, pt.s_, pt.displ_, + pt.scale_) {} + + RealType m() const { + return detail::distribution_base>::m(); + } + + RealType s() const { + return detail::distribution_base>::s(); + } + + RealType displ() const { + return detail::distribution_base>::displ(); + } + + RealType scale() const { + return detail::distribution_base>::scale(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::uniform_bits +// +// Represents discrete uniform bits random number distribution +// +// Supported types: +// std::uint32_t +// std::uint64_t +// +template +class uniform_bits : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/uniform_bits: type is not supported"); + using result_type = UIntType; + +private: + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::bits +// +// Represents bits of underlying random number engine +// +// Supported types: +// std::uint32_t for philox4x32x10, mrg32k3a and mcg31m1 +// std::uint64_t for mcg59 only +// +template +class bits : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/bits: type is not supported"); + using result_type = UIntType; + +private: + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::exponential +// +// Represents continuous exponential random number distribution +// +// Supported types: +// float +// double +// +// Supported methods: +// oneapi::mkl::rng::device::exponential_method::icdf +// oneapi::mkl::rng::device::exponential_method::icdf_accurate +// +// Input arguments: +// displ - displacement. 0.0 by default +// scale - scalefactor. 1.0 by default +// +template +class exponential : detail::distribution_base> { +public: + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/exponential: method is incorrect"); + + static_assert(std::is_same::value || std::is_same::value, + "oneMKL: rng/exponential: type is not supported"); + + using method_type = Method; + using result_type = RealType; + using param_type = + typename detail::distribution_base>::param_type; + + exponential() + : detail::distribution_base>( + static_cast(0.0), static_cast(1.0)) {} + + explicit exponential(RealType a, RealType beta) + : detail::distribution_base>(a, beta) {} + + explicit exponential(const param_type& pt) + : detail::distribution_base>(pt.a_, pt.beta_) {} + + RealType a() const { + return detail::distribution_base>::a(); + } + + RealType beta() const { + return detail::distribution_base>::beta(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::poisson +// +// Represents discrete poisson random number distribution +// +// Supported types: +// std::int32_t +// std::uint32_t +// +// Supported methods: +// oneapi::mkl::rng::device::poisson_method::devroye +// +// Input arguments: +// lambda - mean value. 1.0 by default +// +template +class poisson : detail::distribution_base> { +public: + static_assert(std::is_same::value, + "oneMKL: rng/poisson: method is incorrect"); + + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/poisson: type is not supported"); + + using method_type = Method; + using result_type = IntType; + using param_type = typename detail::distribution_base>::param_type; + + poisson() : detail::distribution_base>(0.5) {} + + explicit poisson(double lambda) : detail::distribution_base>(lambda) {} + explicit poisson(const param_type& pt) + : detail::distribution_base>(pt.lambda_) {} + + double lambda() const { + return detail::distribution_base>::lambda(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +// Class template oneapi::mkl::rng::device::bernoulli +// +// Represents discrete Bernoulli random number distribution +// +// Supported types: +// std::uint32_t +// std::int32_t +// +// Supported methods: +// oneapi::mkl::rng::bernoulli_method::icdf; +// +// Input arguments: +// p - success probablity of a trial. 0.5 by default +// +template +class bernoulli : detail::distribution_base> { +public: + static_assert(std::is_same::value, + "oneMKL: rng/bernoulli: method is incorrect"); + + static_assert(std::is_same::value || + std::is_same::value, + "oneMKL: rng/bernoulli: type is not supported"); + + using method_type = Method; + using result_type = IntType; + using param_type = typename detail::distribution_base>::param_type; + + bernoulli() : detail::distribution_base>(0.5f) {} + + explicit bernoulli(float p) : detail::distribution_base>(p) {} + explicit bernoulli(const param_type& pt) + : detail::distribution_base>(pt.p_) {} + + float p() const { + return detail::distribution_base>::p(); + } + + param_type param() const { + return detail::distribution_base>::param(); + } + + void param(const param_type& pt) { + detail::distribution_base>::param(pt); + } + + template + friend auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type; + template + friend typename Distr::result_type generate_single(Distr& distr, Engine& engine); +}; + +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_DISTRIBUTIONS_HPP_ diff --git a/include/oneapi/mkl/rng/device/engines.hpp b/include/oneapi/mkl/rng/device/engines.hpp new file mode 100644 index 000000000..f1bcfd1b0 --- /dev/null +++ b/include/oneapi/mkl/rng/device/engines.hpp @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_ENGINES_HPP_ +#define _MKL_RNG_DEVICE_ENGINES_HPP_ + +#include + +#include "oneapi/mkl/rng/device/types.hpp" +#include "oneapi/mkl/rng/device/functions.hpp" +#include "oneapi/mkl/rng/device/detail/engine_base.hpp" + +namespace oneapi::mkl::rng::device { + +// PSEUDO-RANDOM NUMBER DEVICE-SIDE ENGINES + +// Class template oneapi::mkl::rng::device::philox4x32x10 +// +// Represents Philox4x32-10 counter-based pseudorandom number generator +// +// Supported parallelization methods: +// skip_ahead +// +template +class philox4x32x10 : detail::engine_base> { +public: + static constexpr std::uint64_t default_seed = 0; + + static constexpr std::int32_t vec_size = VecSize; + + philox4x32x10() : detail::engine_base>(default_seed) {} + + philox4x32x10(std::uint64_t seed, std::uint64_t offset = 0) + : detail::engine_base>(seed, offset) {} + + philox4x32x10(std::initializer_list seed, std::uint64_t offset = 0) + : detail::engine_base>(seed.size(), seed.begin(), offset) {} + + philox4x32x10(std::uint64_t seed, std::initializer_list offset) + : detail::engine_base>(seed, offset.size(), offset.begin()) {} + + philox4x32x10(std::initializer_list seed, + std::initializer_list offset) + : detail::engine_base>(seed.size(), seed.begin(), offset.size(), + offset.begin()) {} + +private: + template + friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip); + + template + friend void skip_ahead(Engine& engine, std::initializer_list num_to_skip); + + template + friend class detail::distribution_base; +}; + +// Class oneapi::mkl::rng::device::mrg32k3a +// +// Represents the combined recurcive pseudorandom number generator +// +// Supported parallelization methods: +// skip_ahead +// +template +class mrg32k3a : detail::engine_base> { +public: + static constexpr std::uint32_t default_seed = 1; + + static constexpr std::int32_t vec_size = VecSize; + + mrg32k3a() : detail::engine_base>(default_seed) {} + + mrg32k3a(std::uint32_t seed, std::uint64_t offset = 0) + : detail::engine_base>(seed, offset) {} + + mrg32k3a(std::initializer_list seed, std::uint64_t offset = 0) + : detail::engine_base>(seed.size(), seed.begin(), offset) {} + + mrg32k3a(std::uint32_t seed, std::initializer_list offset) + : detail::engine_base>(seed, offset.size(), offset.begin()) {} + + mrg32k3a(std::initializer_list seed, std::initializer_list offset) + : detail::engine_base>(seed.size(), seed.begin(), offset.size(), + offset.begin()) {} + +private: + template + friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip); + + template + friend void skip_ahead(Engine& engine, std::initializer_list num_to_skip); + + template + friend class detail::distribution_base; +}; + +// Class oneapi::mkl::rng::device::mcg31m1 +// +// +// +// Supported parallelization methods: +// skip_ahead +// +template +class mcg31m1 : detail::engine_base> { +public: + static constexpr std::uint32_t default_seed = 1; + + static constexpr std::int32_t vec_size = VecSize; + + mcg31m1() : detail::engine_base>(default_seed) {} + + mcg31m1(std::uint32_t seed, std::uint64_t offset = 0) + : detail::engine_base>(seed, offset) {} + +private: + template + friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip); + + template + friend class detail::distribution_base; +}; + +// Class oneapi::mkl::rng::device::mcg59 +// +// +// +// Supported parallelization methods: +// skip_ahead +// +template +class mcg59 : detail::engine_base> { +public: + static constexpr std::uint32_t default_seed = 1; + + static constexpr std::int32_t vec_size = VecSize; + + mcg59() : detail::engine_base>(default_seed) {} + + mcg59(std::uint64_t seed, std::uint64_t offset = 0) + : detail::engine_base>(seed, offset) {} + +private: + template + friend void skip_ahead(Engine& engine, std::uint64_t num_to_skip); + + template + friend class detail::distribution_base; +}; + +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_ENGINES_HPP_ diff --git a/include/oneapi/mkl/rng/device/functions.hpp b/include/oneapi/mkl/rng/device/functions.hpp new file mode 100644 index 000000000..d8542b836 --- /dev/null +++ b/include/oneapi/mkl/rng/device/functions.hpp @@ -0,0 +1,52 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_FUNCTIONS_HPP_ +#define _MKL_RNG_DEVICE_FUNCTIONS_HPP_ + +#include + +#include "oneapi/mkl/rng/device/detail/distribution_base.hpp" + +namespace oneapi::mkl::rng::device { + +// GENERATE FUNCTIONS + +template +auto generate(Distr& distr, Engine& engine) -> + typename std::conditional>::type { + return distr.generate(engine); +} + +// SERVICE FUNCTIONS + +template +void skip_ahead(Engine& engine, std::uint64_t num_to_skip) { + engine.skip_ahead(num_to_skip); +} + +template +void skip_ahead(Engine& engine, std::initializer_list num_to_skip) { + engine.skip_ahead(num_to_skip); +} + +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_FUNCTIONS_HPP_ diff --git a/include/oneapi/mkl/rng/device/types.hpp b/include/oneapi/mkl/rng/device/types.hpp new file mode 100644 index 000000000..e5f74e25b --- /dev/null +++ b/include/oneapi/mkl/rng/device/types.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _MKL_RNG_DEVICE_TYPES_HPP_ +#define _MKL_RNG_DEVICE_TYPES_HPP_ + +namespace oneapi::mkl::rng::device { + +// METHODS FOR DISTRIBUTIONS + +namespace uniform_method { +struct standard {}; +struct accurate {}; +using by_default = standard; +} // namespace uniform_method + +namespace gaussian_method { +struct box_muller2 {}; +struct icdf {}; +using by_default = box_muller2; +} // namespace gaussian_method + +namespace lognormal_method { +struct box_muller2 {}; +using by_default = box_muller2; +} // namespace lognormal_method + +namespace exponential_method { +struct icdf {}; +struct icdf_accurate {}; +using by_default = icdf; +} // namespace exponential_method + +namespace poisson_method { +struct devroye {}; +using by_default = devroye; +} // namespace poisson_method + +namespace bernoulli_method { +struct icdf {}; +using by_default = icdf; +} // namespace bernoulli_method + +} // namespace oneapi::mkl::rng::device + +#endif // _MKL_RNG_DEVICE_TYPES_HPP_ diff --git a/include/oneapi/mkl/rng/predicates.hpp b/include/oneapi/mkl/rng/predicates.hpp index ee942dc89..10422e543 100644 --- a/include/oneapi/mkl/rng/predicates.hpp +++ b/include/oneapi/mkl/rng/predicates.hpp @@ -37,7 +37,7 @@ namespace rng { // Buffer APIs template -inline void generate_precondition(const Distr& distr, Engine& engine, std::int64_t n, +inline void generate_precondition(const Distr& /*distr*/, Engine& /*engine*/, std::int64_t n, sycl::buffer& r) { #ifndef ONEMKL_DISABLE_PREDICATES if (n < 0 || n > r.size()) { @@ -49,9 +49,9 @@ inline void generate_precondition(const Distr& distr, Engine& engine, std::int64 // USM APIs template -inline void generate_precondition(const Distr& distr, Engine& engine, std::int64_t n, +inline void generate_precondition(const Distr& /*distr*/, Engine& /*engine*/, std::int64_t n, typename Distr::result_type* r, - const std::vector& dependencies) { + const std::vector& /*dependencies*/) { #ifndef ONEMKL_DISABLE_PREDICATES if (n < 0) { throw oneapi::mkl::invalid_argument("rng", "generate", "n"); diff --git a/include/oneapi/mkl/sparse_blas.hpp b/include/oneapi/mkl/sparse_blas.hpp new file mode 100644 index 000000000..912a20eb8 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas.hpp @@ -0,0 +1,40 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_HPP_ +#define _ONEMKL_SPARSE_BLAS_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/detail/config.hpp" + +#ifdef ENABLE_MKLCPU_BACKEND +#include "sparse_blas/detail/mklcpu/sparse_blas_ct.hpp" +#endif +#ifdef ENABLE_MKLGPU_BACKEND +#include "sparse_blas/detail/mklgpu/sparse_blas_ct.hpp" +#endif + +#include "sparse_blas/detail/sparse_blas_rt.hpp" + +#endif // _ONEMKL_SPARSE_BLAS_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp b/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp new file mode 100644 index 000000000..4964b1eff --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/helper_types.hpp @@ -0,0 +1,52 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_HELPER_TYPES_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_HELPER_TYPES_HPP_ + +#include +#include +#include + +namespace oneapi { +namespace mkl { +namespace sparse { +namespace detail { + +struct matrix_handle; + +template +inline constexpr bool is_fp_supported_v = + std::is_same_v || std::is_same_v || + std::is_same_v> || std::is_same_v>; + +template +inline constexpr bool is_int_supported_v = + std::is_same_v || std::is_same_v; + +template +inline constexpr bool are_fp_int_supported_v = + is_fp_supported_v&& is_int_supported_v; + +} // namespace detail +} // namespace sparse +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_HELPER_TYPES_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp b/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp new file mode 100644 index 000000000..2535e61f6 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp @@ -0,0 +1,34 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_ONEMKL_SPARSE_BLAS_MKLCPU_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_ONEMKL_SPARSE_BLAS_MKLCPU_HPP_ + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" + +namespace oneapi::mkl::sparse::mklcpu { + +namespace detail = oneapi::mkl::sparse::detail; + +#include "oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx" + +} // namespace oneapi::mkl::sparse::mklcpu + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_ONEMKL_SPARSE_BLAS_MKLCPU_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp new file mode 100644 index 000000000..bc0089c57 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/mklcpu/sparse_blas_ct.hpp @@ -0,0 +1,41 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_SPARSE_BLAS_CT_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_SPARSE_BLAS_CT_HPP_ + +#include "oneapi/mkl/sparse_blas/types.hpp" +#include "oneapi/mkl/detail/backends.hpp" +#include "oneapi/mkl/detail/backend_selector.hpp" + +#include "onemkl_sparse_blas_mklcpu.hpp" + +namespace oneapi { +namespace mkl { +namespace sparse { + +#define BACKEND mklcpu +#include "oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx" +#undef BACKEND + +} //namespace sparse +} //namespace mkl +} //namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_MKLCPU_SPARSE_BLAS_CT_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp b/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp new file mode 100644 index 000000000..1ca336b9b --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp @@ -0,0 +1,34 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_ONEMKL_SPARSE_BLAS_MKLGPU_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_ONEMKL_SPARSE_BLAS_MKLGPU_HPP_ + +#include "oneapi/mkl/detail/export.hpp" +#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" + +namespace oneapi::mkl::sparse::mklgpu { + +namespace detail = oneapi::mkl::sparse::detail; + +#include "oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx" + +} // namespace oneapi::mkl::sparse::mklgpu + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_ONEMKL_SPARSE_BLAS_MKLGPU_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp b/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp new file mode 100644 index 000000000..00c01346f --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/mklgpu/sparse_blas_ct.hpp @@ -0,0 +1,41 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_SPARSE_BLAS_CT_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_SPARSE_BLAS_CT_HPP_ + +#include "oneapi/mkl/sparse_blas/types.hpp" +#include "oneapi/mkl/detail/backends.hpp" +#include "oneapi/mkl/detail/backend_selector.hpp" + +#include "onemkl_sparse_blas_mklgpu.hpp" + +namespace oneapi { +namespace mkl { +namespace sparse { + +#define BACKEND mklgpu +#include "oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx" +#undef BACKEND + +} //namespace sparse +} //namespace mkl +} //namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_MKLGPU_SPARSE_BLAS_CT_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx b/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx new file mode 100644 index 000000000..03beaa4b4 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/onemkl_sparse_blas_backends.hxx @@ -0,0 +1,91 @@ +/*************************************************************************** +* Copyright(C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0(the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +// This file is meant to be included in each backend onemkl_sparse_blas_BACKEND.hpp files. +// It is used to exports each symbol to the onemkl_sparse_blas_BACKEND library. + +ONEMKL_EXPORT void init_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle); + +ONEMKL_EXPORT sycl::event release_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle, + const std::vector &dependencies = {}); + +template +ONEMKL_EXPORT std::enable_if_t> set_csr_data( + sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, intType nnz, + index_base index, sycl::buffer &row_ptr, sycl::buffer &col_ind, + sycl::buffer &val); + +template +ONEMKL_EXPORT std::enable_if_t, sycl::event> +set_csr_data(sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, + intType nnz, index_base index, intType *row_ptr, intType *col_ind, fpType *val, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, + matrix_handle_t handle, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, + transpose transpose_B, layout dense_matrix_layout, + const std::int64_t columns, matrix_handle_t handle, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event optimize_gemv(sycl::queue &queue, transpose transpose_val, + matrix_handle_t handle, + const std::vector &dependencies = {}); + +ONEMKL_EXPORT sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t handle, + const std::vector &dependencies = {}); + +template +ONEMKL_EXPORT std::enable_if_t> gemv( + sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, + sycl::buffer &x, const fpType beta, sycl::buffer &y); + +template +ONEMKL_EXPORT std::enable_if_t, sycl::event> gemv( + sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, + const fpType *x, const fpType beta, fpType *y, + const std::vector &dependencies = {}); + +template +ONEMKL_EXPORT std::enable_if_t> trsv( + sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, sycl::buffer &x, sycl::buffer &y); + +template +ONEMKL_EXPORT std::enable_if_t, sycl::event> trsv( + sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, const fpType *x, fpType *y, + const std::vector &dependencies = {}); + +template +ONEMKL_EXPORT std::enable_if_t> gemm( + sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, + const fpType alpha, matrix_handle_t A_handle, sycl::buffer &B, + const std::int64_t columns, const std::int64_t ldb, const fpType beta, + sycl::buffer &C, const std::int64_t ldc); + +template +ONEMKL_EXPORT std::enable_if_t, sycl::event> gemm( + sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, + const fpType alpha, matrix_handle_t A_handle, const fpType *B, const std::int64_t columns, + const std::int64_t ldb, const fpType beta, fpType *C, const std::int64_t ldc, + const std::vector &dependencies = {}); diff --git a/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx new file mode 100644 index 000000000..41fe51c49 --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_ct.hxx @@ -0,0 +1,135 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +// This file is meant to be included in each backend sparse_blas_ct.hpp files +// Each function calls the implementation from onemkl_sparse_blas_backends.hxx + +#ifndef BACKEND +#error "BACKEND is not defined" +#endif + +inline void init_matrix_handle(backend_selector selector, + matrix_handle_t *p_handle) { + BACKEND::init_matrix_handle(selector.get_queue(), p_handle); +} + +inline sycl::event release_matrix_handle(backend_selector selector, + matrix_handle_t *p_handle, + const std::vector &dependencies = {}) { + return BACKEND::release_matrix_handle(selector.get_queue(), p_handle, dependencies); +} + +template +std::enable_if_t> set_csr_data( + backend_selector selector, matrix_handle_t handle, intType num_rows, + intType num_cols, intType nnz, index_base index, sycl::buffer &row_ptr, + sycl::buffer &col_ind, sycl::buffer &val) { + BACKEND::set_csr_data(selector.get_queue(), handle, num_rows, num_cols, nnz, index, row_ptr, + col_ind, val); +} + +template +std::enable_if_t, sycl::event> set_csr_data( + backend_selector selector, matrix_handle_t handle, intType num_rows, + intType num_cols, intType nnz, index_base index, intType *row_ptr, intType *col_ind, + fpType *val, const std::vector &dependencies = {}) { + return BACKEND::set_csr_data(selector.get_queue(), handle, num_rows, num_cols, nnz, index, + row_ptr, col_ind, val, dependencies); +} + +inline sycl::event optimize_gemm(backend_selector selector, transpose transpose_A, + matrix_handle_t handle, + const std::vector &dependencies = {}) { + return BACKEND::optimize_gemm(selector.get_queue(), transpose_A, handle, dependencies); +} + +inline sycl::event optimize_gemm(backend_selector selector, transpose transpose_A, + transpose transpose_B, layout dense_matrix_layout, + const std::int64_t columns, matrix_handle_t handle, + const std::vector &dependencies = {}) { + return BACKEND::optimize_gemm(selector.get_queue(), transpose_A, transpose_B, + dense_matrix_layout, columns, handle, dependencies); +} + +inline sycl::event optimize_gemv(backend_selector selector, + transpose transpose_val, matrix_handle_t handle, + const std::vector &dependencies = {}) { + return BACKEND::optimize_gemv(selector.get_queue(), transpose_val, handle, dependencies); +} + +inline sycl::event optimize_trsv(backend_selector selector, uplo uplo_val, + transpose transpose_val, diag diag_val, matrix_handle_t handle, + const std::vector &dependencies = {}) { + return BACKEND::optimize_trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, handle, + dependencies); +} + +template +std::enable_if_t> gemv( + backend_selector selector, transpose transpose_val, const fpType alpha, + matrix_handle_t A_handle, sycl::buffer &x, const fpType beta, + sycl::buffer &y) { + BACKEND::gemv(selector.get_queue(), transpose_val, alpha, A_handle, x, beta, y); +} + +template +std::enable_if_t, sycl::event> gemv( + backend_selector selector, transpose transpose_val, const fpType alpha, + matrix_handle_t A_handle, const fpType *x, const fpType beta, fpType *y, + const std::vector &dependencies = {}) { + return BACKEND::gemv(selector.get_queue(), transpose_val, alpha, A_handle, x, beta, y, + dependencies); +} + +template +std::enable_if_t> trsv( + backend_selector selector, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, sycl::buffer &x, + sycl::buffer &y) { + BACKEND::trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, x, y); +} + +template +std::enable_if_t, sycl::event> trsv( + backend_selector selector, uplo uplo_val, transpose transpose_val, + diag diag_val, matrix_handle_t A_handle, const fpType *x, fpType *y, + const std::vector &dependencies = {}) { + return BACKEND::trsv(selector.get_queue(), uplo_val, transpose_val, diag_val, A_handle, x, y, + dependencies); +} + +template +std::enable_if_t> gemm( + backend_selector selector, layout dense_matrix_layout, transpose transpose_A, + transpose transpose_B, const fpType alpha, matrix_handle_t A_handle, sycl::buffer &B, + const std::int64_t columns, const std::int64_t ldb, const fpType beta, + sycl::buffer &C, const std::int64_t ldc) { + BACKEND::gemm(selector.get_queue(), dense_matrix_layout, transpose_A, transpose_B, alpha, + A_handle, B, columns, ldb, beta, C, ldc); +} + +template +std::enable_if_t, sycl::event> gemm( + backend_selector selector, layout dense_matrix_layout, transpose transpose_A, + transpose transpose_B, const fpType alpha, matrix_handle_t A_handle, const fpType *B, + const std::int64_t columns, const std::int64_t ldb, const fpType beta, fpType *C, + const std::int64_t ldc, const std::vector &dependencies = {}) { + return BACKEND::gemm(selector.get_queue(), dense_matrix_layout, transpose_A, transpose_B, alpha, + A_handle, B, columns, ldb, beta, C, ldc, dependencies); +} diff --git a/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp new file mode 100644 index 000000000..131e0545a --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp @@ -0,0 +1,103 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_DETAIL_SPARSE_BLAS_RT_HPP_ +#define _ONEMKL_SPARSE_BLAS_DETAIL_SPARSE_BLAS_RT_HPP_ + +#include "oneapi/mkl/sparse_blas/types.hpp" + +namespace oneapi { +namespace mkl { +namespace sparse { + +void init_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle); + +sycl::event release_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle, + const std::vector &dependencies = {}); + +template +std::enable_if_t> set_csr_data( + sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, intType nnz, + index_base index, sycl::buffer &row_ptr, sycl::buffer &col_ind, + sycl::buffer &val); + +template +std::enable_if_t, sycl::event> set_csr_data( + sycl::queue &queue, matrix_handle_t handle, intType num_rows, intType num_cols, intType nnz, + index_base index, intType *row_ptr, intType *col_ind, fpType *val, + const std::vector &dependencies = {}); + +sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, matrix_handle_t handle, + const std::vector &dependencies = {}); + +sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, transpose transpose_B, + layout dense_matrix_layout, const std::int64_t columns, + matrix_handle_t handle, + const std::vector &dependencies = {}); + +sycl::event optimize_gemv(sycl::queue &queue, transpose transpose_val, matrix_handle_t handle, + const std::vector &dependencies = {}); + +sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t handle, + const std::vector &dependencies = {}); + +template +std::enable_if_t> gemv( + sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, + sycl::buffer &x, const fpType beta, sycl::buffer &y); + +template +std::enable_if_t, sycl::event> gemv( + sycl::queue &queue, transpose transpose_val, const fpType alpha, matrix_handle_t A_handle, + const fpType *x, const fpType beta, fpType *y, + const std::vector &dependencies = {}); + +template +std::enable_if_t> trsv(sycl::queue &queue, uplo uplo_val, + transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, + sycl::buffer &x, + sycl::buffer &y); + +template +std::enable_if_t, sycl::event> trsv( + sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t A_handle, const fpType *x, fpType *y, + const std::vector &dependencies = {}); + +template +std::enable_if_t> gemm( + sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, + const fpType alpha, matrix_handle_t A_handle, sycl::buffer &B, + const std::int64_t columns, const std::int64_t ldb, const fpType beta, + sycl::buffer &C, const std::int64_t ldc); + +template +std::enable_if_t, sycl::event> gemm( + sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, + const fpType alpha, matrix_handle_t A_handle, const fpType *B, const std::int64_t columns, + const std::int64_t ldb, const fpType beta, fpType *C, const std::int64_t ldc, + const std::vector &dependencies = {}); + +} // namespace sparse +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_DETAIL_SPARSE_BLAS_RT_HPP_ diff --git a/include/oneapi/mkl/sparse_blas/types.hpp b/include/oneapi/mkl/sparse_blas/types.hpp new file mode 100644 index 000000000..406c7dd1f --- /dev/null +++ b/include/oneapi/mkl/sparse_blas/types.hpp @@ -0,0 +1,44 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_TYPES_HPP_ +#define _ONEMKL_SPARSE_BLAS_TYPES_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include + +#include "oneapi/mkl/types.hpp" +#include "detail/helper_types.hpp" + +namespace oneapi { +namespace mkl { +namespace sparse { + +using matrix_handle_t = detail::matrix_handle*; + +} // namespace sparse +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_SPARSE_BLAS_TYPES_HPP_ diff --git a/include/oneapi/mkl/types.hpp b/include/oneapi/mkl/types.hpp index 67f924dde..32d336e11 100644 --- a/include/oneapi/mkl/types.hpp +++ b/include/oneapi/mkl/types.hpp @@ -20,7 +20,10 @@ #ifndef _ONEMKL_TYPES_HPP_ #define _ONEMKL_TYPES_HPP_ +#ifdef __HIPSYCL__ #include "oneapi/mkl/bfloat16.hpp" +#endif + #if __has_include() #include #else @@ -30,6 +33,10 @@ namespace oneapi { namespace mkl { +#ifndef __HIPSYCL__ +using bfloat16 = sycl::ext::oneapi::bfloat16; +#endif + // BLAS flag types. enum class transpose : char { nontrans = 0, trans = 1, conjtrans = 3, N = 0, T = 1, C = 3 }; @@ -41,7 +48,7 @@ enum class side : char { left = 0, right = 1, L = 0, R = 1 }; enum class offset : char { row = 0, column = 1, fix = 2, R = 0, C = 1, F = 2 }; -enum class layout : char { column_major = 0, row_major = 1, C = 0, R = 1 }; +enum class layout : char { row_major = 0, col_major = 1, R = 0, C = 1 }; enum class index_base : char { zero = 0, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2d73bca25..0b632c1bd 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,7 +23,24 @@ if(WIN32 AND BUILD_SHARED_LIBS) list(APPEND ONEMKL_BUILD_COPT "-Donemkl_EXPORTS") endif() +# portBLAS backend variables must be accessible here to correctly +# generate the config file. +set(ENABLE_PORTBLAS_BACKEND_INTEL_CPU OFF CACHE INTERNAL "") +set(ENABLE_PORTBLAS_BACKEND_INTEL_GPU OFF CACHE INTERNAL "") +set(ENABLE_PORTBLAS_BACKEND_AMD_GPU OFF CACHE INTERNAL "") +set(ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU OFF CACHE INTERNAL "") +# store path to CMAKE_CURRENT_BINARY_DIR to use it later (makes FetchContent_Declare workable) +set(ONEMKL_GENERATED_INCLUDE_PATH ${CMAKE_CURRENT_BINARY_DIR}) + + +set(ONEMKL_INTERFACE_INCLUDE_DIRS + $ + $ + $ +) + # Build loader and backends for each domain +add_custom_target(onemkl_backend_libs) foreach(domain ${TARGET_DOMAINS}) add_subdirectory(${domain}) endforeach() @@ -43,14 +60,22 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/oneapi/mkl/detail/config.hpp" if(BUILD_SHARED_LIBS) add_library(onemkl SHARED) + # The loader library depends on all the backend libraries as it uses + # dlopen to load them at runtime. + # Use add_dependencies to ensure that all the backend libraries are + # (re-)built when compiling the loader or runtime binaries. + add_dependencies(onemkl onemkl_backend_libs) + target_include_directories(onemkl - PUBLIC $ - $ - $ + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} ) set_target_properties(onemkl PROPERTIES SOVERSION ${PROJECT_VERSION_MAJOR} ) + # w/a for setting oneMKL Interfaces installed headers as -I instead of -isystem for cmake >= 3.25 for workable find_package(MKL) combination + if(CMAKE_VERSION VERSION_GREATER_EQUAL "3.25.0") + set_target_properties(onemkl PROPERTIES EXPORT_NO_SYSTEM true) + endif() # Build dispatcher library set (ONEMKL_LIBS ${TARGET_DOMAINS}) diff --git a/src/blas/CMakeLists.txt b/src/blas/CMakeLists.txt index 3656c5ac1..1edf2e445 100644 --- a/src/blas/CMakeLists.txt +++ b/src/blas/CMakeLists.txt @@ -29,6 +29,7 @@ target_include_directories(onemkl_blas ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} $ ) diff --git a/src/blas/backends/CMakeLists.txt b/src/blas/backends/CMakeLists.txt index 31dfb92be..351f4b0e5 100644 --- a/src/blas/backends/CMakeLists.txt +++ b/src/blas/backends/CMakeLists.txt @@ -17,6 +17,9 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== +add_custom_target(onemkl_backend_libs_blas) +add_dependencies(onemkl_backend_libs onemkl_backend_libs_blas) + if(ENABLE_MKLCPU_BACKEND) add_subdirectory(mklcpu) endif() @@ -36,3 +39,7 @@ endif() if(ENABLE_ROCBLAS_BACKEND AND UNIX) add_subdirectory(rocblas) endif() + +if(ENABLE_PORTBLAS_BACKEND AND UNIX) + add_subdirectory(portblas) +endif() diff --git a/src/blas/backends/backend_wrappers.cxx b/src/blas/backends/backend_wrappers.cxx index 19d38e0c5..62f6ced13 100644 --- a/src/blas/backends/backend_wrappers.cxx +++ b/src/blas/backends/backend_wrappers.cxx @@ -200,6 +200,9 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, oneapi::mkl::blas::BACKEND::MAJOR::trsm_batch, @@ -228,6 +231,10 @@ oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, oneapi::mkl::blas::BACKEND::MAJOR::imatcopy, oneapi::mkl::blas::BACKEND::MAJOR::imatcopy, oneapi::mkl::blas::BACKEND::MAJOR::imatcopy, @@ -451,6 +458,12 @@ oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, +oneapi::mkl::blas::BACKEND::MAJOR::gemm_batch, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, oneapi::mkl::blas::BACKEND::MAJOR::gemmt, @@ -475,6 +488,10 @@ oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, oneapi::mkl::blas::BACKEND::MAJOR::omatcopy, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, +oneapi::mkl::blas::BACKEND::MAJOR::omatcopy2, oneapi::mkl::blas::BACKEND::MAJOR::imatcopy, oneapi::mkl::blas::BACKEND::MAJOR::imatcopy, oneapi::mkl::blas::BACKEND::MAJOR::imatcopy, diff --git a/src/blas/backends/cublas/CMakeLists.txt b/src/blas/backends/cublas/CMakeLists.txt index 1b54ccdbf..b64e7c37d 100644 --- a/src/blas/backends/cublas/CMakeLists.txt +++ b/src/blas/backends/cublas/CMakeLists.txt @@ -30,13 +30,22 @@ set(SOURCES cublas_level1.cpp $<$: cublas_wrappers.cpp>) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/src + ${ONEMKL_GENERATED_INCLUDE_PATH} ) target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=nvptx64-nvidia-cuda -fsycl-unnamed-lambda) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=nvptx64-nvidia-cuda) +endif() target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::cuBLAS::cuBLAS) target_compile_features(${LIB_OBJ} PUBLIC cxx_std_11) set_target_properties(${LIB_OBJ} PROPERTIES diff --git a/src/blas/backends/cublas/cublas_batch.cpp b/src/blas/backends/cublas/cublas_batch.cpp index beefd6eeb..009bb9541 100644 --- a/src/blas/backends/cublas/cublas_batch.cpp +++ b/src/blas/backends/cublas/cublas_batch.cpp @@ -140,16 +140,21 @@ void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -template -inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, T alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, T beta, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - using cuDataType = typename CudaEquivalentType::Type; +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -158,33 +163,56 @@ inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, tra auto c_acc = c.template get_access(cgh); onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); - auto c_ = sc.get_mem(c_acc); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, - ldc, stride_c, batch_size); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a_, + get_cublas_datatype(), lda, stride_a, b_, get_cublas_datatype(), + ldb, stride_b, &beta, c_, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); }); }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, \ + lda, stride_a, b, ldb, stride_b, beta, c, \ + ldc, stride_c, batch_size); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) + +#undef GEMM_STRIDED_BATCH_LAUNCHER + +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -553,17 +581,23 @@ sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t throw unimplemented("blas", "dgmm_batch", "for column_major layout"); } -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - T alpha, const T *a, int64_t lda, int64_t stride_a, const T *b, - int64_t ldb, int64_t stride_b, T beta, T *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stride_a, + const Tb *b, int64_t ldb, int64_t stride_b, Ts beta, + Tc *c, int64_t ldc, int64_t stride_c, + int64_t batch_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -573,50 +607,74 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que } onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); - auto c_ = reinterpret_cast(c); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), - get_cublas_operation(transb), m, n, k, (cuDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (cuDataType *)&beta, c_, - ldc, stride_c, batch_size); + CUBLAS_ERROR_FUNC_T_SYNC( + "cublasGemmStridedBatchedEx", cublasGemmStridedBatchedEx, err, handle, + get_cublas_operation(transa), get_cublas_operation(transb), m, n, k, &alpha, a, + get_cublas_datatype(), lda, stride_a, b, get_cublas_datatype(), + ldb, stride_b, &beta, c, get_cublas_datatype(), ldc, stride_c, batch_size, + get_cublas_datatype(), cublas_gemm_algo); }); }); return done; } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, \ + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, \ + batch_size, dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose *transa, transpose *transb, int64_t *m, int64_t *n, - int64_t *k, T *alpha, const T **a, int64_t *lda, const T **b, - int64_t *ldb, T *beta, T **c, int64_t *ldc, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - using cuDataType = typename CudaEquivalentType::Type; +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using cuTypeA = typename CudaEquivalentType::Type; + using cuTypeB = typename CudaEquivalentType::Type; + using cuTypeC = typename CudaEquivalentType::Type; + using cuTypeS = typename CudaEquivalentType::Type; for (int64_t i = 0; i < group_count; i++) { overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); } + + cublasGemmAlgo_t cublas_gemm_algo = CUBLAS_GEMM_DEFAULT; auto done = queue.submit([&](sycl::handler &cgh) { - if (!verify_support(queue, sycl::aspect::fp16)) { + if (!verify_support(queue, sycl::aspect::fp16)) { throw oneapi::mkl::unimplemented( "blas", "sycl::half", "half is not supported by the device or the sycl compiler"); } @@ -629,14 +687,14 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que int64_t offset = 0; cublasStatus_t err; for (int64_t i = 0; i < group_count; i++) { - auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); - auto **c_ = reinterpret_cast(c); CUBLAS_ERROR_FUNC_T_SYNC( - func_name, func, err, handle, get_cublas_operation(transa[i]), - get_cublas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], - (cuDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (cuDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]); + "cublasGemmBatchedEx", cublasGemmBatchedEx, err, handle, + get_cublas_operation(transa[i]), get_cublas_operation(transb[i]), (int)m[i], + (int)n[i], (int)k[i], &alpha[i], (const void *const *)(a + offset), + get_cublas_datatype(), (int)lda[i], (const void *const *)(b + offset), + get_cublas_datatype(), (int)ldb[i], &beta[i], + (void *const *)(c + offset), get_cublas_datatype(), (int)ldc[i], + (int)group_size[i], get_cublas_datatype(), cublas_gemm_algo); offset += group_size[i]; } }); @@ -644,21 +702,41 @@ inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &que return done; } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmBatched) -GEMM_BATCH_LAUNCHER_USM(float, cublasSgemmBatched) -GEMM_BATCH_LAUNCHER_USM(double, cublasDgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasCgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasZgemmBatched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) #undef GEMM_BATCH_LAUNCHER_USM @@ -1066,30 +1144,25 @@ void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -template -inline void gemm_batch(const char *func_name, Func func, sycl::queue &queue, transpose transa, - transpose transb, int64_t m, int64_t n, int64_t k, T alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, - int64_t ldb, int64_t stride_b, T beta, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ - void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ - int64_t batch_size) { \ - gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER @@ -1458,59 +1531,47 @@ sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t throw unimplemented("blas", "dgmm_batch", "for row_major layout"); } -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - T alpha, const T *a, int64_t lda, int64_t stride_a, const T *b, - int64_t ldb, int64_t stride_b, T beta, T *c, int64_t ldc, - int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stride_a, const TYPE_B *b, int64_t ldb, int64_t stride_b, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stride_c, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, cublasSgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, cublasDgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasCgemmStridedBatched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, cublasZgemmStridedBatched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(const char *func_name, Func func, sycl::queue &queue, - transpose *transa, transpose *transb, int64_t *m, int64_t *n, - int64_t *k, T *alpha, const T **a, int64_t *lda, const T **b, - int64_t *ldb, T *beta, T **c, int64_t *ldc, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); -} - -#define GEMM_BATCH_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, \ - a, lda, b, ldb, beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", "for row_major layout"); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, cublasHgemmBatched) -GEMM_BATCH_LAUNCHER_USM(float, cublasSgemmBatched) -GEMM_BATCH_LAUNCHER_USM(double, cublasDgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasCgemmBatched) -GEMM_BATCH_LAUNCHER_USM(std::complex, cublasZgemmBatched) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) #undef GEMM_BATCH_LAUNCHER_USM diff --git a/src/blas/backends/cublas/cublas_extensions.cpp b/src/blas/backends/cublas/cublas_extensions.cpp index 6a462444a..cc80b483d 100644 --- a/src/blas/backends/cublas/cublas_extensions.cpp +++ b/src/blas/backends/cublas/cublas_extensions.cpp @@ -85,27 +85,62 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for column_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +template +void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? n : m); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), + get_cublas_operation(trans), logical_m, logical_n, + (cuDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + }); + }); +} + +#define OMATCOPY_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } + +OMATCOPY_LAUNCHER(float, cublasSgeam) +OMATCOPY_LAUNCHER(double, cublasDgeam) +OMATCOPY_LAUNCHER(std::complex, cublasCgeam) +OMATCOPY_LAUNCHER(std::complex, cublasZgeam) + +#undef OMATCOPY_LAUNCHER + +template +void omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +#define OMATCOPY2_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &b, int64_t ldb, int64_t strideb) { \ + omatcopy2(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, lda, b, \ + ldb, strideb); \ + } + +OMATCOPY2_LAUNCHER(float, "unimplemented") +OMATCOPY2_LAUNCHER(double, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") +#undef OMATCOPY2_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -127,31 +162,43 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for column_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +template +void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, + T beta, sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, (cuDataType *)&alpha, a_, + lda, (cuDataType *)&beta, b_, ldb, c_, ldc); + }); + }); +} + +#define OMATADD_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, \ + b, ldb, c, ldc); \ + } + +OMATADD_LAUNCHER(float, cublasSgeam) +OMATADD_LAUNCHER(double, cublasDgeam) +OMATADD_LAUNCHER(std::complex, cublasCgeam) +OMATADD_LAUNCHER(std::complex, cublasZgeam) + +#undef OMATADD_LAUNCHER // USM APIs @@ -217,31 +264,64 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for column_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, +template +sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? m : n); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? n : m); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), + get_cublas_operation(trans), logical_m, logical_n, + (cuDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + }); + }); + return done; +} + +#define OMATCOPY_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, \ + ldb, dependencies); \ + } + +OMATCOPY_LAUNCHER_USM(float, cublasSgeam) +OMATCOPY_LAUNCHER_USM(double, cublasDgeam) +OMATCOPY_LAUNCHER_USM(std::complex, cublasCgeam) +OMATCOPY_LAUNCHER_USM(std::complex, cublasZgeam) + +#undef OMATCOPY_LAUNCHER_USM + +template +sycl::event omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, int64_t stridea, T *b, + int64_t ldb, int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); +} + +#define OMATCOPY2_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, int64_t stridea, TYPE *b, int64_t ldb, \ + int64_t strideb, const std::vector &dependencies) { \ + return omatcopy2(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, \ + lda, b, ldb, strideb, dependencies); \ + } + +OMATCOPY2_LAUNCHER_USM(float, "unimplemented") +OMATCOPY2_LAUNCHER_USM(double, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") +#undef OMATCOPY2_LAUNCHER_USM sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, @@ -267,37 +347,47 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for column_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +template +inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, + T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), m, n, (cuDataType *)&alpha, a_, + lda, (cuDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + return done; +} + +#define OMATADD_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, TYPE alpha, const TYPE *a, int64_t lda, TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, \ + lda, beta, b, ldb, c, ldc, dependencies); \ + } + +OMATADD_LAUNCHER_USM(float, cublasSgeam) +OMATADD_LAUNCHER_USM(double, cublasDgeam) +OMATADD_LAUNCHER_USM(std::complex, cublasCgeam) +OMATADD_LAUNCHER_USM(std::complex, cublasZgeam) + +#undef OMATADD_LAUNCHER_USM } // namespace column_major + namespace row_major { // Buffer APIs @@ -358,27 +448,62 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for row_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} +template +void omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? n : m); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? m : n); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), + get_cublas_operation(trans), logical_m, logical_n, + (cuDataType *)&alpha, a_, lda, nullptr, nullptr, lda, b_, ldb); + }); + }); +} + +#define OMATCOPY_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } + +OMATCOPY_LAUNCHER(float, cublasSgeam) +OMATCOPY_LAUNCHER(double, cublasDgeam) +OMATCOPY_LAUNCHER(std::complex, cublasCgeam) +OMATCOPY_LAUNCHER(std::complex, cublasZgeam) + +#undef OMATCOPY_LAUNCHER + +template +void omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +#define OMATCOPY2_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &b, int64_t ldb, int64_t strideb) { \ + omatcopy2(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, lda, b, \ + ldb, strideb); \ + } + +OMATCOPY2_LAUNCHER(float, "unimplemented") +OMATCOPY2_LAUNCHER(double, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") +#undef OMATCOPY2_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -400,31 +525,43 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for row_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +template +void omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, + T beta, sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), n, m, (cuDataType *)&alpha, a_, + lda, (cuDataType *)&beta, b_, ldb, c_, ldc); + }); + }); +} + +#define OMATADD_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + TYPE alpha, sycl::buffer &a, int64_t lda, TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, \ + b, ldb, c, ldc); \ + } + +OMATADD_LAUNCHER(float, cublasSgeam) +OMATADD_LAUNCHER(double, cublasDgeam) +OMATADD_LAUNCHER(std::complex, cublasCgeam) +OMATADD_LAUNCHER(std::complex, cublasZgeam) + +#undef OMATADD_LAUNCHER // USM APIs @@ -490,31 +627,64 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for row_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, +template +sycl::event omatcopy(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} - -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + const int64_t logical_m = (trans == oneapi::mkl::transpose::nontrans ? n : m); + const int64_t logical_n = (trans == oneapi::mkl::transpose::nontrans ? m : n); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(trans), + get_cublas_operation(trans), logical_m, logical_n, + (cuDataType *)&alpha, a_, lda, nullptr, nullptr, ldb, b_, ldb); + }); + }); + return done; +} + +#define OMATCOPY_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, \ + ldb, dependencies); \ + } + +OMATCOPY_LAUNCHER_USM(float, cublasSgeam) +OMATCOPY_LAUNCHER_USM(double, cublasDgeam) +OMATCOPY_LAUNCHER_USM(std::complex, cublasCgeam) +OMATCOPY_LAUNCHER_USM(std::complex, cublasZgeam) + +#undef OMATCOPY_LAUNCHER_USM + +template +sycl::event omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, int64_t stridea, T *b, + int64_t ldb, int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); +} + +#define OMATCOPY2_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, int64_t stridea, TYPE *b, int64_t ldb, \ + int64_t strideb, const std::vector &dependencies) { \ + return omatcopy2(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, \ + lda, b, ldb, strideb, dependencies); \ + } + +OMATCOPY2_LAUNCHER_USM(float, "unimplemented") +OMATCOPY2_LAUNCHER_USM(double, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") +#undef OMATCOPY2_LAUNCHER_USM sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, @@ -540,35 +710,44 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for row_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +template +inline sycl::event omatadd(const char *func_name, Func func, sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, + T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using cuDataType = typename CudaEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + cublasStatus_t err; + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_cublas_operation(transa), + get_cublas_operation(transb), n, m, (cuDataType *)&alpha, a_, + lda, (cuDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + return done; +} + +#define OMATADD_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, TYPE alpha, const TYPE *a, int64_t lda, TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(#CUBLAS_ROUTINE, CUBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, \ + lda, beta, b, ldb, c, ldc, dependencies); \ + } + +OMATADD_LAUNCHER_USM(float, cublasSgeam) +OMATADD_LAUNCHER_USM(double, cublasDgeam) +OMATADD_LAUNCHER_USM(std::complex, cublasCgeam) +OMATADD_LAUNCHER_USM(std::complex, cublasZgeam) + +#undef OMATADD_LAUNCHER_USM } // namespace row_major } // namespace cublas diff --git a/src/blas/backends/cublas/cublas_helper.hpp b/src/blas/backends/cublas/cublas_helper.hpp index 0ee9930e3..0fe7e7c5a 100644 --- a/src/blas/backends/cublas/cublas_helper.hpp +++ b/src/blas/backends/cublas/cublas_helper.hpp @@ -35,6 +35,7 @@ #include "oneapi/mkl/types.hpp" #include "runtime_support_helper.hpp" +#include "dtype_string.hpp" namespace oneapi { namespace mkl { @@ -231,6 +232,56 @@ inline cublasSideMode_t get_cublas_side_mode(oneapi::mkl::side lr) { } } +template +inline cudaDataType_t get_cublas_datatype() { + static_assert(false); +} + +template <> +inline cudaDataType_t get_cublas_datatype<__half>() { + return CUDA_R_16F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_64F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_C_32F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_C_64F; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8I; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_8U; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32I; +} + +template <> +inline cudaDataType_t get_cublas_datatype() { + return CUDA_R_32U; +} + /*converting std::complex to cuComplex*/ /*converting sycl::half to __half*/ template diff --git a/src/blas/backends/cublas/cublas_level1.cpp b/src/blas/backends/cublas/cublas_level1.cpp index 49ed0db1b..5f7087727 100644 --- a/src/blas/backends/cublas/cublas_level1.cpp +++ b/src/blas/backends/cublas/cublas_level1.cpp @@ -473,10 +473,13 @@ inline void iamax(const char *func_name, Func func, sycl::queue &queue, int64_t cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); }); }); - // This requires to bring the data to host, copy it, and return it back to - // the device - result.template get_host_access(sycl::write_only)[0] = std::max( - (int64_t)int_res_buff.template get_host_access(sycl::read_only)[0] - 1, int64_t{ 0 }); + + queue.submit([&](sycl::handler &cgh) { + auto int_res_acc = int_res_buff.template get_access(cgh); + auto result_acc = result.template get_access(cgh); + cgh.single_task( + [=]() { result_acc[0] = std::max((int64_t)int_res_acc[0] - 1, (int64_t)0); }); + }); } #define IAMAX_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ @@ -556,8 +559,13 @@ inline void iamin(const char *func_name, Func func, sycl::queue &queue, int64_t cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); }); }); - result.template get_host_access(sycl::write_only)[0] = std::max( - (int64_t)int_res_buff.template get_host_access(sycl::read_only)[0] - 1, int64_t{ 0 }); + + queue.submit([&](sycl::handler &cgh) { + auto int_res_acc = int_res_buff.template get_access(cgh); + auto result_acc = result.template get_access(cgh); + cgh.single_task( + [=]() { result_acc[0] = std::max((int64_t)int_res_acc[0] - 1, (int64_t)0); }); + }); } #define IAMIN_LAUNCHER(TYPE, CUBLAS_ROUTINE) \ @@ -624,6 +632,8 @@ inline sycl::event asum(const char *func_name, Func func, sycl::queue &queue, in using cuDataType2 = typename CudaEquivalentType::Type; overflow_check(n, incx); + bool result_on_device = + sycl::get_pointer_type(result, queue.get_context()) == sycl::usm::alloc::device; auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -633,9 +643,15 @@ inline sycl::event asum(const char *func_name, Func func, sycl::queue &queue, in auto handle = sc.get_handle(queue); auto x_ = reinterpret_cast(x); auto res_ = reinterpret_cast(result); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; // ASUM does not support negative index CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); return done; @@ -752,6 +768,21 @@ inline sycl::event rotg(const char *func_name, Func func, sycl::queue &queue, T1 T1 *s, const std::vector &dependencies) { using cuDataType1 = typename CudaEquivalentType::Type; using cuDataType2 = typename CudaEquivalentType::Type; + auto ctx = queue.get_context(); + bool results_on_device = (sycl::get_pointer_type(a, ctx) == sycl::usm::alloc::device || + sycl::get_pointer_type(b, ctx) == sycl::usm::alloc::device || + sycl::get_pointer_type(c, ctx) == sycl::usm::alloc::device || + sycl::get_pointer_type(s, ctx) == sycl::usm::alloc::device); + if (results_on_device) { + if (sycl::get_pointer_type(a, ctx) == sycl::usm::alloc::unknown || + sycl::get_pointer_type(b, ctx) == sycl::usm::alloc::unknown || + sycl::get_pointer_type(c, ctx) == sycl::usm::alloc::unknown || + sycl::get_pointer_type(s, ctx) == sycl::usm::alloc::unknown) { + throw oneapi::mkl::exception( + "blas", "rotg", + "If any pointer is only device accessible, all must be device accessible"); + } + } auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -763,8 +794,14 @@ inline sycl::event rotg(const char *func_name, Func func, sycl::queue &queue, T1 auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); auto s_ = reinterpret_cast(s); + if (results_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, a_, b_, c_, s_); + if (results_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); return done; @@ -856,6 +893,8 @@ inline sycl::event dot(const char *func_name, Func func, sycl::queue &queue, int const std::vector &dependencies) { using cuDataType = typename CudaEquivalentType::Type; overflow_check(n, incx, incy); + bool result_on_device = + sycl::get_pointer_type(result, queue.get_context()) == sycl::usm::alloc::device; auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -866,8 +905,14 @@ inline sycl::event dot(const char *func_name, Func func, sycl::queue &queue, int auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); auto res_ = reinterpret_cast(result); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, y_, incy, res_); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); return done; @@ -931,7 +976,9 @@ sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int6 const float *y, int64_t incy, float *result, const std::vector &dependencies) { overflow_check(n, incx, incy); - // cuBLAS does not support sdot so we need to mimic sdot. + bool result_on_device = + sycl::get_pointer_type(result, queue.get_context()) == sycl::usm::alloc::device; + // cuBLAS does not support sdsdot so we need to mimic sdot. auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -942,13 +989,32 @@ sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int6 auto x_ = reinterpret_cast(x); auto y_ = reinterpret_cast(y); auto res_ = reinterpret_cast(result); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; CUBLAS_ERROR_FUNC_SYNC(cublasSdot, err, handle, n, x_, incx, y_, incy, res_); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); done.wait(); - result[0] = result[0] + sb; - return done; + if (result_on_device) { + // The following does copy device to host and then host to device + // just to adjust with sb constant. This is pretty inefficient, and + // should maybe be replaced with a sycl GPU kernel, but it duplicated what + // is done in the buffer API + float host_result; + queue.memcpy(&host_result, result, sizeof(float)).wait(); + host_result += sb; + auto last_ev = queue.memcpy(result, &host_result, sizeof(float)); + return last_ev; + } + else { + result[0] = result[0] + sb; + return done; + } } sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, @@ -960,6 +1026,24 @@ template inline sycl::event rotmg(const char *func_name, Func func, sycl::queue &queue, T *d1, T *d2, T *x1, T y1, T *param, const std::vector &dependencies) { using cuDataType = typename CudaEquivalentType::Type; + auto ctx = queue.get_context(); + bool results_on_device = (sycl::get_pointer_type(d1, ctx) == sycl::usm::alloc::device || + sycl::get_pointer_type(d2, ctx) == sycl::usm::alloc::device || + sycl::get_pointer_type(x1, ctx) == sycl::usm::alloc::device); + if (results_on_device) { + if (sycl::get_pointer_type(d1, ctx) == sycl::usm::alloc::unknown || + sycl::get_pointer_type(d2, ctx) == sycl::usm::alloc::unknown || + sycl::get_pointer_type(x1, ctx) == sycl::usm::alloc::unknown) { + throw oneapi::mkl::exception( + "blas", "rotmg", + "If any pointer is only device accessible, all must be device accessible"); + } + } + cuDataType *y1_; + if (results_on_device) { + y1_ = sycl::malloc_device(1, queue); + queue.memcpy(y1_, &y1, sizeof(cuDataType)).wait(); + } auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -970,12 +1054,24 @@ inline sycl::event rotmg(const char *func_name, Func func, sycl::queue &queue, T auto d1_ = reinterpret_cast(d1); auto d2_ = reinterpret_cast(d2); auto x1_ = reinterpret_cast(x1); - auto y1_ = reinterpret_cast(&y1); auto param_ = reinterpret_cast(param); cublasStatus_t err; - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); + if (results_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_, param_); + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } + else { + auto y1_c = reinterpret_cast(&y1); + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, d1_, d2_, x1_, y1_c, param_); + } }); }); + if (results_on_device) { + done.wait(); + queue.memcpy(&y1, y1_, sizeof(cuDataType)).wait(); + sycl::free(y1_, queue); + } return done; } @@ -1001,7 +1097,15 @@ inline sycl::event iamax(const char *func_name, Func func, sycl::queue &queue, i // This change may cause failure as the result of integer overflow // based on the size. int int_res = 0; - int *int_res_p = &int_res; + int *int_res_p = nullptr; + bool result_on_device = + sycl::get_pointer_type(result, queue.get_context()) == sycl::usm::alloc::device; + if (result_on_device) { + int_res_p = sycl::malloc_device(1, queue); + } + else { + int_res_p = &int_res; + } auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1010,16 +1114,31 @@ inline sycl::event iamax(const char *func_name, Func func, sycl::queue &queue, i onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto x_ = reinterpret_cast(x); - auto int_res_p_ = reinterpret_cast(int_res_p); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; // For negative incx, iamax returns 0. This behaviour is similar to that of // reference iamax. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p_); + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); done.wait(); - result[0] = std::max((int64_t)(*int_res_p - 1), int64_t{ 0 }); - return done; + if (result_on_device) { + auto last_ev = queue.submit([&](sycl::handler &cgh) { + cgh.single_task([=]() { *result = std::max((int64_t)*int_res_p - 1, (int64_t)0); }); + }); + last_ev.wait(); + sycl::free(int_res_p, queue); + return last_ev; + } + else { + result[0] = std::max((int64_t)(*int_res_p - 1), int64_t{ 0 }); + return done; + } } #define IAMAX_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ @@ -1079,7 +1198,15 @@ inline sycl::event iamin(const char *func_name, Func func, sycl::queue &queue, i // This change may cause failure as the result of integer overflow // based on the size. int int_res = 0; - int *int_res_p = &int_res; + int *int_res_p = nullptr; + bool result_on_device = + sycl::get_pointer_type(result, queue.get_context()) == sycl::usm::alloc::device; + if (result_on_device) { + int_res_p = sycl::malloc_device(1, queue); + } + else { + int_res_p = &int_res; + } auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1088,16 +1215,31 @@ inline sycl::event iamin(const char *func_name, Func func, sycl::queue &queue, i onemkl_cublas_host_task(cgh, queue, [=](CublasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto x_ = reinterpret_cast(x); - auto int_res_p_ = reinterpret_cast(int_res_p); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; // For negative incx, iamin returns 0. This behaviour is similar to that of // implemented iamin. - CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p_); + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, incx, int_res_p); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); done.wait(); - result[0] = std::max((int64_t)(*int_res_p - 1), int64_t{ 0 }); - return done; + if (result_on_device) { + auto last_ev = queue.submit([&](sycl::handler &cgh) { + cgh.single_task([=]() { *result = std::max((int64_t)*int_res_p - 1, (int64_t)0); }); + }); + last_ev.wait(); + sycl::free(int_res_p, queue); + return last_ev; + } + else { + result[0] = std::max((int64_t)(*int_res_p - 1), int64_t{ 0 }); + return done; + } } #define IAMIN_LAUNCHER_USM(TYPE, CUBLAS_ROUTINE) \ @@ -1119,6 +1261,8 @@ inline sycl::event nrm2(const char *func_name, Func func, sycl::queue &queue, in using cuDataType2 = typename CudaEquivalentType::Type; overflow_check(n, incx); + bool result_on_device = + sycl::get_pointer_type(result, queue.get_context()) == sycl::usm::alloc::device; auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1128,9 +1272,15 @@ inline sycl::event nrm2(const char *func_name, Func func, sycl::queue &queue, in auto handle = sc.get_handle(queue); auto x_ = reinterpret_cast(x); auto res_ = reinterpret_cast(result); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_DEVICE); + } cublasStatus_t err; // NRM2 does not support negative index CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, handle, n, x_, std::abs(incx), res_); + if (result_on_device) { + cublasSetPointerMode(handle, CUBLAS_POINTER_MODE_HOST); + } }); }); return done; diff --git a/src/blas/backends/cublas/cublas_scope_handle.cpp b/src/blas/backends/cublas/cublas_scope_handle.cpp index f3e39ca11..05d1c1935 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.cpp +++ b/src/blas/backends/cublas/cublas_scope_handle.cpp @@ -42,10 +42,11 @@ CublasScopedContextHandler::CublasScopedContextHandler(sycl::queue queue, sycl:: : ih(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto device = queue.get_device(); - auto desired = sycl::get_native(*placedContext_); + auto cudaDevice = ih.get_native_device(); CUresult err; + CUcontext desired; CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); @@ -87,8 +88,11 @@ void ContextCallback(void *userData) { } cublasHandle_t CublasScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto cudaDevice = ih.get_native_device(); + CUresult cuErr; + CUcontext desired; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); + auto piPlacedContext_ = reinterpret_cast(desired); CUstream streamId = get_stream(queue); cublasStatus_t err; auto it = handle_helper.cublas_handle_mapper_.find(piPlacedContext_); diff --git a/src/blas/backends/cublas/cublas_scope_handle.hpp b/src/blas/backends/cublas/cublas_scope_handle.hpp index eec8cec03..7648130be 100644 --- a/src/blas/backends/cublas/cublas_scope_handle.hpp +++ b/src/blas/backends/cublas/cublas_scope_handle.hpp @@ -23,8 +23,10 @@ #else #include #endif -#if __has_include() +#if __has_include() +#if __SYCL_COMPILER_VERSION <= 20220930 #include +#endif #include #include #else diff --git a/src/blas/backends/cublas/cublas_task.hpp b/src/blas/backends/cublas/cublas_task.hpp index e1cf56a3b..e5cf0d7c2 100644 --- a/src/blas/backends/cublas/cublas_task.hpp +++ b/src/blas/backends/cublas/cublas_task.hpp @@ -1,3 +1,24 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) and Computing Centre (URZ) +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + #ifndef _MKL_BLAS_CUBLAS_TASK_HPP_ #define _MKL_BLAS_CUBLAS_TASK_HPP_ #include @@ -41,7 +62,6 @@ static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { cgh.host_task([f, queue](sycl::interop_handle ih) { auto sc = CublasScopedContextHandler(queue, ih); f(sc); - sc.wait_stream(queue); }); } #endif diff --git a/src/blas/backends/cublas/cublas_wrappers.cpp b/src/blas/backends/cublas/cublas_wrappers.cpp index 2b545bf5f..ee5c7239f 100644 --- a/src/blas/backends/cublas/cublas_wrappers.cpp +++ b/src/blas/backends/cublas/cublas_wrappers.cpp @@ -205,6 +205,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, oneapi::mkl::blas::cublas::column_major::trsm_batch, @@ -233,6 +236,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::omatcopy, oneapi::mkl::blas::cublas::column_major::omatcopy, oneapi::mkl::blas::cublas::column_major::omatcopy, + oneapi::mkl::blas::cublas::column_major::omatcopy2, + oneapi::mkl::blas::cublas::column_major::omatcopy2, + oneapi::mkl::blas::cublas::column_major::omatcopy2, + oneapi::mkl::blas::cublas::column_major::omatcopy2, oneapi::mkl::blas::cublas::column_major::imatcopy, oneapi::mkl::blas::cublas::column_major::imatcopy, oneapi::mkl::blas::cublas::column_major::imatcopy, @@ -456,6 +463,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, + oneapi::mkl::blas::cublas::column_major::gemm_batch, oneapi::mkl::blas::cublas::column_major::gemmt, oneapi::mkl::blas::cublas::column_major::gemmt, oneapi::mkl::blas::cublas::column_major::gemmt, @@ -480,6 +493,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::column_major::omatcopy, oneapi::mkl::blas::cublas::column_major::omatcopy, oneapi::mkl::blas::cublas::column_major::omatcopy, + oneapi::mkl::blas::cublas::column_major::omatcopy2, + oneapi::mkl::blas::cublas::column_major::omatcopy2, + oneapi::mkl::blas::cublas::column_major::omatcopy2, + oneapi::mkl::blas::cublas::column_major::omatcopy2, oneapi::mkl::blas::cublas::column_major::imatcopy, oneapi::mkl::blas::cublas::column_major::imatcopy, oneapi::mkl::blas::cublas::column_major::imatcopy, @@ -678,6 +695,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, oneapi::mkl::blas::cublas::row_major::trsm_batch, @@ -706,6 +726,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::omatcopy, oneapi::mkl::blas::cublas::row_major::omatcopy, oneapi::mkl::blas::cublas::row_major::omatcopy, + oneapi::mkl::blas::cublas::row_major::omatcopy2, + oneapi::mkl::blas::cublas::row_major::omatcopy2, + oneapi::mkl::blas::cublas::row_major::omatcopy2, + oneapi::mkl::blas::cublas::row_major::omatcopy2, oneapi::mkl::blas::cublas::row_major::imatcopy, oneapi::mkl::blas::cublas::row_major::imatcopy, oneapi::mkl::blas::cublas::row_major::imatcopy, @@ -929,6 +953,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, + oneapi::mkl::blas::cublas::row_major::gemm_batch, oneapi::mkl::blas::cublas::row_major::gemmt, oneapi::mkl::blas::cublas::row_major::gemmt, oneapi::mkl::blas::cublas::row_major::gemmt, @@ -953,6 +983,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::cublas::row_major::omatcopy, oneapi::mkl::blas::cublas::row_major::omatcopy, oneapi::mkl::blas::cublas::row_major::omatcopy, + oneapi::mkl::blas::cublas::row_major::omatcopy2, + oneapi::mkl::blas::cublas::row_major::omatcopy2, + oneapi::mkl::blas::cublas::row_major::omatcopy2, + oneapi::mkl::blas::cublas::row_major::omatcopy2, oneapi::mkl::blas::cublas::row_major::imatcopy, oneapi::mkl::blas::cublas::row_major::imatcopy, oneapi::mkl::blas::cublas::row_major::imatcopy, diff --git a/src/blas/backends/mkl_common/mkl_batch.cxx b/src/blas/backends/mkl_common/mkl_batch.cxx index 0a204d5b7..6358a3922 100644 --- a/src/blas/backends/mkl_common/mkl_batch.cxx +++ b/src/blas/backends/mkl_common/mkl_batch.cxx @@ -182,6 +182,33 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t stride_b, beta, c, ldc, stride_c, batch_size); } +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { + blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size); +} + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -642,6 +669,33 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i stride_b, beta, c, ldc, stride_c, batch_size, dependencies); } +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, + stride_b, beta, c, ldc, stride_c, batch_size, dependencies); +} + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, int64_t *n, int64_t *k, float *alpha, const float **a, int64_t *lda, const float **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, @@ -689,6 +743,33 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, ldc, group_count, groupsize, dependencies); } +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", + "unsupported dtype combination: int8_t, int8_t, float, float"); +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + return blas_major::gemm_batch(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc, group_count, groupsize, dependencies); +} + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, diff --git a/src/blas/backends/mkl_common/mkl_blas_backend.hpp b/src/blas/backends/mkl_common/mkl_blas_backend.hpp index 32cd0f263..d45208a6d 100644 --- a/src/blas/backends/mkl_common/mkl_blas_backend.hpp +++ b/src/blas/backends/mkl_common/mkl_blas_backend.hpp @@ -21,10 +21,50 @@ #include +#include "mkl_version.h" #include "oneapi/mkl/types.hpp" namespace oneapi { namespace mkl { + +template +class value_or_pointer { + T value_; + const T* ptr_; + +public: + // Constructor from value. Accepts not only type T but anything convertible to T. + template , int> = 0> + value_or_pointer(U value) : value_(value), + ptr_(nullptr) {} + + // Constructor from pointer, assumed to be device-accessible. + value_or_pointer(const T* ptr) : value_(T(0)), ptr_(ptr) {} + + bool fixed() const { + return ptr_ == nullptr; + } + + T get_fixed_value() const { + return value_; + } + + const T* get_pointer() const { + return ptr_; + } + + T get() const { + return ptr_ ? *ptr_ : value_; + } + + void make_device_accessible(sycl::queue& queue) { + if (!fixed() && + sycl::get_pointer_type(ptr_, queue.get_context()) == sycl::usm::alloc::unknown) { + *this = *ptr_; + } + } +}; + namespace blas { namespace column_major { diff --git a/src/blas/backends/mkl_common/mkl_blas_backend.hxx b/src/blas/backends/mkl_common/mkl_blas_backend.hxx index 4d49321ef..10e441bd7 100644 --- a/src/blas/backends/mkl_common/mkl_blas_backend.hxx +++ b/src/blas/backends/mkl_common/mkl_blas_backend.hxx @@ -51,13 +51,13 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m void gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, - sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, - std::int64_t ldc); + sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc); void gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, - sycl::buffer &b, std::int64_t ldb, float beta, sycl::buffer &c, - std::int64_t ldc); + sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc); void gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, std::int64_t lda, @@ -188,194 +188,209 @@ void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans // level 3, USM sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, - const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, const float *a, + std::int64_t lda, const float *b, std::int64_t ldb, value_or_pointer beta, + float *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, double alpha, const double *a, std::int64_t lda, - const double *b, std::int64_t ldb, double beta, double *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, const double *a, + std::int64_t lda, const double *b, std::int64_t ldb, value_or_pointer beta, + double *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, std::complex alpha, + std::int64_t n, std::int64_t k, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, const std::complex *b, - std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldb, value_or_pointer> beta, std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, std::complex alpha, + std::int64_t n, std::int64_t k, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, const std::complex *b, - std::int64_t ldb, std::complex beta, std::complex *c, - std::int64_t ldc, const std::vector &dependencies = {}); + std::int64_t ldb, value_or_pointer> beta, + std::complex *c, std::int64_t ldc, + const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, sycl::half alpha, const sycl::half *a, - std::int64_t lda, const sycl::half *b, std::int64_t ldb, sycl::half beta, - sycl::half *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, + value_or_pointer beta, sycl::half *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, - const sycl::half *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, const sycl::half *a, + std::int64_t lda, const sycl::half *b, std::int64_t ldb, value_or_pointer beta, + float *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const bfloat16 *a, std::int64_t lda, - const bfloat16 *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const bfloat16 *a, std::int64_t lda, const bfloat16 *b, + std::int64_t ldb, value_or_pointer beta, float *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const bfloat16 *a, std::int64_t lda, - const bfloat16 *b, std::int64_t ldb, float beta, bfloat16 *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const bfloat16 *a, std::int64_t lda, const bfloat16 *b, + std::int64_t ldb, value_or_pointer beta, bfloat16 *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, - std::int64_t lda, const std::int8_t *b, std::int64_t ldb, float beta, - std::int32_t *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, value_or_pointer alpha, const std::int8_t *a, + std::int64_t lda, const std::int8_t *b, std::int64_t ldb, + value_or_pointer beta, std::int32_t *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, - std::int64_t lda, const std::int8_t *b, std::int64_t ldb, float beta, float *c, - std::int64_t ldc, const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, const std::int8_t *a, + std::int64_t lda, const std::int8_t *b, std::int64_t ldb, + value_or_pointer beta, float *c, std::int64_t ldc, + const std::vector &dependencies = {}); sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, - std::int64_t n, float alpha, const float *a, std::int64_t lda, const float *b, - std::int64_t ldb, float beta, float *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, value_or_pointer alpha, const float *a, std::int64_t lda, + const float *b, std::int64_t ldb, value_or_pointer beta, float *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, - std::int64_t n, double alpha, const double *a, std::int64_t lda, const double *b, - std::int64_t ldb, double beta, double *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, value_or_pointer alpha, const double *a, std::int64_t lda, + const double *b, std::int64_t ldb, value_or_pointer beta, double *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, - std::complex beta, std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer> beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, - std::complex beta, std::complex *c, std::int64_t ldc, + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer> beta, + std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, - std::complex beta, std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer> beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, - std::complex beta, std::complex *c, std::int64_t ldc, + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer> beta, + std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, float alpha, const float *a, std::int64_t lda, float beta, - float *c, std::int64_t ldc, const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer alpha, const float *a, std::int64_t lda, + value_or_pointer beta, float *c, std::int64_t ldc, + const std::vector &dependencies = {}); sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, double alpha, const double *a, std::int64_t lda, double beta, - double *c, std::int64_t ldc, const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer alpha, const double *a, std::int64_t lda, + value_or_pointer beta, double *c, std::int64_t ldc, + const std::vector &dependencies = {}); sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, std::complex beta, std::complex *c, - std::int64_t ldc, const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, + value_or_pointer> beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies = {}); sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, std::complex beta, std::complex *c, + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, + value_or_pointer> beta, std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, float alpha, const std::complex *a, std::int64_t lda, - float beta, std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer alpha, const std::complex *a, + std::int64_t lda, value_or_pointer beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, double alpha, const std::complex *a, std::int64_t lda, - double beta, std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer alpha, const std::complex *a, + std::int64_t lda, value_or_pointer beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, - std::int64_t ldb, float beta, float *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer alpha, const float *a, std::int64_t lda, + const float *b, std::int64_t ldb, value_or_pointer beta, float *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, double alpha, const double *a, std::int64_t lda, const double *b, - std::int64_t ldb, double beta, double *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer alpha, const double *a, std::int64_t lda, + const double *b, std::int64_t ldb, value_or_pointer beta, double *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, - std::complex beta, std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer> beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, - std::complex beta, std::complex *c, std::int64_t ldc, + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer> beta, + std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, float beta, - std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, const std::complex *b, std::int64_t ldb, double beta, - std::complex *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, value_or_pointer beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, float alpha, const float *a, - std::int64_t lda, float *b, std::int64_t ldb, + diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const float *a, std::int64_t lda, float *b, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, double alpha, const double *a, - std::int64_t lda, double *b, std::int64_t ldb, + diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const double *a, std::int64_t lda, double *b, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, std::complex *b, - std::int64_t ldb, const std::vector &dependencies = {}); + diag unit_diag, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies = {}); sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, std::complex *b, - std::int64_t ldb, const std::vector &dependencies = {}); + diag unit_diag, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies = {}); sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, float alpha, const float *a, - std::int64_t lda, float *b, std::int64_t ldb, + diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const float *a, std::int64_t lda, float *b, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, double alpha, const double *a, - std::int64_t lda, double *b, std::int64_t ldb, + diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const double *a, std::int64_t lda, double *b, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, std::complex *b, - std::int64_t ldb, const std::vector &dependencies = {}); + diag unit_diag, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies = {}); sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, std::complex *b, - std::int64_t ldb, const std::vector &dependencies = {}); + diag unit_diag, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies = {}); // level 2, buffer @@ -658,195 +673,214 @@ void trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, // level 2, USM -sycl::event gemv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, - const float *a, std::int64_t lda, const float *x, std::int64_t incx, float beta, - float *y, std::int64_t incy, const std::vector &dependencies = {}); - -sycl::event gemv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, double alpha, - const double *a, std::int64_t lda, const double *x, std::int64_t incx, double beta, - double *y, std::int64_t incy, const std::vector &dependencies = {}); +sycl::event gemv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + value_or_pointer alpha, const float *a, std::int64_t lda, const float *x, + std::int64_t incx, value_or_pointer beta, float *y, std::int64_t incy, + const std::vector &dependencies = {}); sycl::event gemv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, std::int64_t lda, - const std::complex *x, std::int64_t incx, std::complex beta, - std::complex *y, std::int64_t incy, + value_or_pointer alpha, const double *a, std::int64_t lda, const double *x, + std::int64_t incx, value_or_pointer beta, double *y, std::int64_t incy, const std::vector &dependencies = {}); sycl::event gemv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, std::int64_t lda, - const std::complex *x, std::int64_t incx, std::complex beta, - std::complex *y, std::int64_t incy, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); +sycl::event gemv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies = {}); + sycl::event gbmv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::int64_t kl, std::int64_t ku, float alpha, const float *a, std::int64_t lda, - const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, - const std::vector &dependencies = {}); + std::int64_t kl, std::int64_t ku, value_or_pointer alpha, const float *a, + std::int64_t lda, const float *x, std::int64_t incx, value_or_pointer beta, + float *y, std::int64_t incy, const std::vector &dependencies = {}); sycl::event gbmv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::int64_t kl, std::int64_t ku, double alpha, const double *a, std::int64_t lda, - const double *x, std::int64_t incx, double beta, double *y, std::int64_t incy, - const std::vector &dependencies = {}); + std::int64_t kl, std::int64_t ku, value_or_pointer alpha, const double *a, + std::int64_t lda, const double *x, std::int64_t incx, value_or_pointer beta, + double *y, std::int64_t incy, const std::vector &dependencies = {}); sycl::event gbmv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::int64_t kl, std::int64_t ku, std::complex alpha, + std::int64_t kl, std::int64_t ku, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, const std::complex *x, - std::int64_t incx, std::complex beta, std::complex *y, + std::int64_t incx, value_or_pointer> beta, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); sycl::event gbmv(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::int64_t kl, std::int64_t ku, std::complex alpha, + std::int64_t kl, std::int64_t ku, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, const std::complex *x, - std::int64_t incx, std::complex beta, std::complex *y, - std::int64_t incy, const std::vector &dependencies = {}); + std::int64_t incx, value_or_pointer> beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies = {}); -sycl::event ger(sycl::queue &queue, std::int64_t m, std::int64_t n, float alpha, const float *x, - std::int64_t incx, const float *y, std::int64_t incy, float *a, std::int64_t lda, - const std::vector &dependencies = {}); +sycl::event ger(sycl::queue &queue, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const float *x, std::int64_t incx, const float *y, std::int64_t incy, float *a, + std::int64_t lda, const std::vector &dependencies = {}); -sycl::event ger(sycl::queue &queue, std::int64_t m, std::int64_t n, double alpha, const double *x, - std::int64_t incx, const double *y, std::int64_t incy, double *a, std::int64_t lda, - const std::vector &dependencies = {}); +sycl::event ger(sycl::queue &queue, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const double *x, std::int64_t incx, const double *y, std::int64_t incy, double *a, + std::int64_t lda, const std::vector &dependencies = {}); -sycl::event gerc(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, std::int64_t lda, +sycl::event gerc(sycl::queue &queue, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event gerc(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, std::int64_t lda, +sycl::event gerc(sycl::queue &queue, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event geru(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, std::int64_t lda, +sycl::event geru(sycl::queue &queue, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event geru(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, std::int64_t lda, +sycl::event geru(sycl::queue &queue, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); sycl::event hbmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, - std::complex alpha, const std::complex *a, std::int64_t lda, - const std::complex *x, std::int64_t incx, std::complex beta, - std::complex *y, std::int64_t incy, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); sycl::event hbmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, - std::complex alpha, const std::complex *a, std::int64_t lda, - const std::complex *x, std::int64_t incx, std::complex beta, - std::complex *y, std::int64_t incy, - const std::vector &dependencies = {}); - -sycl::event hemv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, const std::complex *x, - std::int64_t incx, std::complex beta, std::complex *y, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event hemv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, const std::complex *x, - std::int64_t incx, std::complex beta, std::complex *y, +sycl::event hemv(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, std::int64_t incy, + const std::vector &dependencies = {}); + +sycl::event hemv(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event her(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, +sycl::event her(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, const std::complex *x, std::int64_t incx, std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event her(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, +sycl::event her(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, const std::complex *x, std::int64_t incx, std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event her2(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, std::int64_t lda, +sycl::event her2(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event her2(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, std::int64_t lda, +sycl::event her2(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event hpmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *a, const std::complex *x, std::int64_t incx, - std::complex beta, std::complex *y, std::int64_t incy, +sycl::event hpmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event hpmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *a, const std::complex *x, std::int64_t incx, - std::complex beta, std::complex *y, std::int64_t incy, - const std::vector &dependencies = {}); +sycl::event hpmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies = {}); -sycl::event hpr(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, +sycl::event hpr(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, const std::complex *x, std::int64_t incx, std::complex *a, const std::vector &dependencies = {}); -sycl::event hpr(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, +sycl::event hpr(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, const std::complex *x, std::int64_t incx, std::complex *a, const std::vector &dependencies = {}); -sycl::event hpr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, - const std::vector &dependencies = {}); +sycl::event hpr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, const std::vector &dependencies = {}); -sycl::event hpr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex *y, - std::int64_t incy, std::complex *a, - const std::vector &dependencies = {}); +sycl::event hpr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *a, const std::vector &dependencies = {}); -sycl::event sbmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, float alpha, - const float *a, std::int64_t lda, const float *x, std::int64_t incx, float beta, - float *y, std::int64_t incy, const std::vector &dependencies = {}); +sycl::event sbmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + value_or_pointer alpha, const float *a, std::int64_t lda, const float *x, + std::int64_t incx, value_or_pointer beta, float *y, std::int64_t incy, + const std::vector &dependencies = {}); -sycl::event sbmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, double alpha, - const double *a, std::int64_t lda, const double *x, std::int64_t incx, double beta, - double *y, std::int64_t incy, const std::vector &dependencies = {}); +sycl::event sbmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, std::int64_t k, + value_or_pointer alpha, const double *a, std::int64_t lda, const double *x, + std::int64_t incx, value_or_pointer beta, double *y, std::int64_t incy, + const std::vector &dependencies = {}); -sycl::event symv(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *a, - std::int64_t lda, const float *x, std::int64_t incx, float beta, float *y, - std::int64_t incy, const std::vector &dependencies = {}); +sycl::event symv(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const float *a, std::int64_t lda, const float *x, std::int64_t incx, + value_or_pointer beta, float *y, std::int64_t incy, + const std::vector &dependencies = {}); -sycl::event symv(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, - const double *a, std::int64_t lda, const double *x, std::int64_t incx, double beta, - double *y, std::int64_t incy, const std::vector &dependencies = {}); +sycl::event symv(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const double *a, std::int64_t lda, const double *x, std::int64_t incx, + value_or_pointer beta, double *y, std::int64_t incy, + const std::vector &dependencies = {}); -sycl::event syr(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, - std::int64_t incx, float *a, std::int64_t lda, +sycl::event syr(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const float *x, std::int64_t incx, float *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event syr(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, const double *x, - std::int64_t incx, double *a, std::int64_t lda, +sycl::event syr(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const double *x, std::int64_t incx, double *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event syr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, - std::int64_t incx, const float *y, std::int64_t incy, float *a, std::int64_t lda, - const std::vector &dependencies = {}); +sycl::event syr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const float *x, std::int64_t incx, const float *y, std::int64_t incy, float *a, + std::int64_t lda, const std::vector &dependencies = {}); -sycl::event syr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, +sycl::event syr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, const double *x, std::int64_t incx, const double *y, std::int64_t incy, double *a, std::int64_t lda, const std::vector &dependencies = {}); -sycl::event spmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *a, - const float *x, std::int64_t incx, float beta, float *y, std::int64_t incy, - const std::vector &dependencies = {}); +sycl::event spmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const float *a, const float *x, std::int64_t incx, value_or_pointer beta, + float *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event spmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, - const double *a, const double *x, std::int64_t incx, double beta, double *y, - std::int64_t incy, const std::vector &dependencies = {}); +sycl::event spmv(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const double *a, const double *x, std::int64_t incx, value_or_pointer beta, + double *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event spr(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, - std::int64_t incx, float *a, const std::vector &dependencies = {}); +sycl::event spr(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const float *x, std::int64_t incx, float *a, + const std::vector &dependencies = {}); -sycl::event spr(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, const double *x, - std::int64_t incx, double *a, const std::vector &dependencies = {}); +sycl::event spr(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const double *x, std::int64_t incx, double *a, + const std::vector &dependencies = {}); -sycl::event spr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, float alpha, const float *x, - std::int64_t incx, const float *y, std::int64_t incy, float *a, +sycl::event spr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, + const float *x, std::int64_t incx, const float *y, std::int64_t incy, float *a, const std::vector &dependencies = {}); -sycl::event spr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, double alpha, +sycl::event spr2(sycl::queue &queue, uplo upper_lower, std::int64_t n, value_or_pointer alpha, const double *x, std::int64_t incx, const double *y, std::int64_t incy, double *a, const std::vector &dependencies = {}); @@ -973,28 +1007,32 @@ void dotu(sycl::queue &queue, std::int64_t n, sycl::buffer, sycl::buffer, 1> &result); void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, - sycl::buffer &result); + sycl::buffer &result, index_base base=index_base::zero); void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, - sycl::buffer &result); + sycl::buffer &result, index_base base=index_base::zero); void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, - std::int64_t incx, sycl::buffer &result); + std::int64_t incx, sycl::buffer &result, + index_base base=index_base::zero); void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, - std::int64_t incx, sycl::buffer &result); + std::int64_t incx, sycl::buffer &result, + index_base base=index_base::zero); void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, - sycl::buffer &result); + sycl::buffer &result, index_base base=index_base::zero); void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, - sycl::buffer &result); + sycl::buffer &result, index_base base=index_base::zero); void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, - std::int64_t incx, sycl::buffer &result); + std::int64_t incx, sycl::buffer &result, + index_base base=index_base::zero); void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, - std::int64_t incx, sycl::buffer &result); + std::int64_t incx, sycl::buffer &result, + index_base base=index_base::zero); void asum(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, std::int64_t incx, sycl::buffer &result); @@ -1203,38 +1241,39 @@ sycl::event asum(sycl::queue &queue, std::int64_t n, const float *x, std::int64_ sycl::event asum(sycl::queue &queue, std::int64_t n, const double *x, std::int64_t incx, double *result, const std::vector &dependencies = {}); -sycl::event axpy(sycl::queue &queue, std::int64_t n, float alpha, const float *x, std::int64_t incx, - float *y, std::int64_t incy, const std::vector &dependencies = {}); +sycl::event axpy(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, const float *x, + std::int64_t incx, float *y, std::int64_t incy, + const std::vector &dependencies = {}); -sycl::event axpy(sycl::queue &queue, std::int64_t n, double alpha, const double *x, +sycl::event axpy(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, const double *x, std::int64_t incx, double *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event axpy(sycl::queue &queue, std::int64_t n, std::complex alpha, +sycl::event axpy(sycl::queue &queue, std::int64_t n, value_or_pointer> alpha, const std::complex *x, std::int64_t incx, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event axpy(sycl::queue &queue, std::int64_t n, std::complex alpha, +sycl::event axpy(sycl::queue &queue, std::int64_t n, value_or_pointer> alpha, const std::complex *x, std::int64_t incx, std::complex *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event axpby(sycl::queue &queue, std::int64_t n, float alpha, const float *x, - std::int64_t incx, const float beta, float *y, std::int64_t incy, +sycl::event axpby(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, const float *x, + std::int64_t incx, value_or_pointer beta, float *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event axpby(sycl::queue &queue, std::int64_t n, double alpha, const double *x, - std::int64_t incx, const double beta, double *y, std::int64_t incy, +sycl::event axpby(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, const double *x, + std::int64_t incx, value_or_pointer beta, double *y, std::int64_t incy, const std::vector &dependencies = {}); -sycl::event axpby(sycl::queue &queue, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex beta, - std::complex *y, std::int64_t incy, - const std::vector &dependencies = {}); +sycl::event axpby(sycl::queue &queue, std::int64_t n, value_or_pointer> alpha, + const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies = {}); -sycl::event axpby(sycl::queue &queue, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, const std::complex beta, - std::complex *y, std::int64_t incy, - const std::vector &dependencies = {}); +sycl::event axpby(sycl::queue &queue, std::int64_t n, value_or_pointer> alpha, + const std::complex *x, std::int64_t incx, + value_or_pointer> beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies = {}); sycl::event copy(sycl::queue &queue, std::int64_t n, const float *x, std::int64_t incx, float *y, std::int64_t incy, const std::vector &dependencies = {}); @@ -1281,19 +1320,19 @@ sycl::event nrm2(sycl::queue &queue, std::int64_t n, const double *x, std::int64 double *result, const std::vector &dependencies = {}); sycl::event rot(sycl::queue &queue, std::int64_t n, std::complex *x, std::int64_t incx, - std::complex *y, std::int64_t incy, float c, float s, - const std::vector &dependencies = {}); + std::complex *y, std::int64_t incy, value_or_pointer c, + value_or_pointer s, const std::vector &dependencies = {}); sycl::event rot(sycl::queue &queue, std::int64_t n, std::complex *x, std::int64_t incx, - std::complex *y, std::int64_t incy, double c, double s, - const std::vector &dependencies = {}); + std::complex *y, std::int64_t incy, value_or_pointer c, + value_or_pointer s, const std::vector &dependencies = {}); sycl::event rot(sycl::queue &queue, std::int64_t n, float *x, std::int64_t incx, float *y, - std::int64_t incy, float c, float s, + std::int64_t incy, value_or_pointer c, value_or_pointer s, const std::vector &dependencies = {}); sycl::event rot(sycl::queue &queue, std::int64_t n, double *x, std::int64_t incx, double *y, - std::int64_t incy, double c, double s, + std::int64_t incy, value_or_pointer c, value_or_pointer s, const std::vector &dependencies = {}); sycl::event rotg(sycl::queue &queue, float *a, float *b, float *c, float *s, @@ -1308,15 +1347,6 @@ sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex *b, double *c, std::complex *s, const std::vector &dependencies = {}); -#if defined(INTEL_MKL_VERSION) && (INTEL_MKL_VERSION < 20230000) -sycl::event rotm(sycl::queue &queue, std::int64_t n, float *x, std::int64_t incx, float *y, - std::int64_t incy, float *param, - const std::vector &dependencies = {}); - -sycl::event rotm(sycl::queue &queue, std::int64_t n, double *x, std::int64_t incx, double *y, - std::int64_t incy, double *param, - const std::vector &dependencies = {}); -#else sycl::event rotm(sycl::queue &queue, std::int64_t n, float *x, std::int64_t incx, float *y, std::int64_t incy, const float *param, const std::vector &dependencies = {}); @@ -1324,34 +1354,30 @@ sycl::event rotm(sycl::queue &queue, std::int64_t n, float *x, std::int64_t incx sycl::event rotm(sycl::queue &queue, std::int64_t n, double *x, std::int64_t incx, double *y, std::int64_t incy, const double *param, const std::vector &dependencies = {}); -#endif -sycl::event rotmg(sycl::queue &queue, float *d1, float *d2, float *x1, float y1, float *param, - const std::vector &dependencies = {}); - -sycl::event rotmg(sycl::queue &queue, double *d1, double *d2, double *x1, double y1, double *param, - const std::vector &dependencies = {}); +sycl::event rotmg(sycl::queue &queue, float *d1, float *d2, float *x1, value_or_pointer y1, + float *param, const std::vector &dependencies = {}); -sycl::event scal(sycl::queue &queue, std::int64_t n, float alpha, float *x, std::int64_t incx, - const std::vector &dependencies = {}); +sycl::event rotmg(sycl::queue &queue, double *d1, double *d2, double *x1, value_or_pointer y1, + double *param, const std::vector &dependencies = {}); -sycl::event scal(sycl::queue &queue, std::int64_t n, double alpha, double *x, std::int64_t incx, - const std::vector &dependencies = {}); - -sycl::event scal(sycl::queue &queue, std::int64_t n, std::complex alpha, - std::complex *x, std::int64_t incx, - const std::vector &dependencies = {}); - -sycl::event scal(sycl::queue &queue, std::int64_t n, std::complex alpha, - std::complex *x, std::int64_t incx, - const std::vector &dependencies = {}); +#define ONEMKL_DECLARE_SCAL(T, Ts) \ + sycl::event scal(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, T *x, \ + std::int64_t incx, const std::vector &dependencies = {}); +ONEMKL_DECLARE_SCAL(float, float) +ONEMKL_DECLARE_SCAL(double, double) +ONEMKL_DECLARE_SCAL(std::complex, std::complex) +ONEMKL_DECLARE_SCAL(std::complex, std::complex) +ONEMKL_DECLARE_SCAL(std::complex, float) +ONEMKL_DECLARE_SCAL(std::complex, double) sycl::event scal(sycl::queue &queue, std::int64_t n, float alpha, std::complex *x, std::int64_t incx, const std::vector &dependencies = {}); - sycl::event scal(sycl::queue &queue, std::int64_t n, double alpha, std::complex *x, std::int64_t incx, const std::vector &dependencies = {}); +#undef ONEMKL_DECLARE_SCAL + sycl::event swap(sycl::queue &queue, std::int64_t n, float *x, std::int64_t incx, float *y, std::int64_t incy, const std::vector &dependencies = {}); @@ -1419,54 +1445,55 @@ void gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset of // extensions, USM sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, - std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, - const float *b, std::int64_t ldb, float beta, float *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, const float *a, + std::int64_t lda, const float *b, std::int64_t ldb, value_or_pointer beta, + float *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, - std::int64_t n, std::int64_t k, double alpha, const double *a, std::int64_t lda, - const double *b, std::int64_t ldb, double beta, double *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, const double *a, + std::int64_t lda, const double *b, std::int64_t ldb, value_or_pointer beta, + double *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, - std::int64_t n, std::int64_t k, std::complex alpha, + std::int64_t n, std::int64_t k, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, const std::complex *b, - std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldb, value_or_pointer> beta, std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, - std::int64_t n, std::int64_t k, std::complex alpha, + std::int64_t n, std::int64_t k, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, const std::complex *b, - std::int64_t ldb, std::complex beta, std::complex *c, - std::int64_t ldc, const std::vector &dependencies = {}); + std::int64_t ldb, value_or_pointer> beta, + std::complex *c, std::int64_t ldc, + const std::vector &dependencies = {}); sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + std::int64_t m, std::int64_t n, std::int64_t k, value_or_pointer alpha, const std::int8_t *a, std::int64_t lda, std::int8_t ao, const std::uint8_t *b, - std::int64_t ldb, std::uint8_t bo, float beta, std::int32_t *c, + std::int64_t ldb, std::uint8_t bo, value_or_pointer beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, const std::vector &dependencies = {}); sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + std::int64_t m, std::int64_t n, std::int64_t k, value_or_pointer alpha, const std::int8_t *a, std::int64_t lda, std::int8_t ao, const std::int8_t *b, - std::int64_t ldb, std::int8_t bo, float beta, std::int32_t *c, + std::int64_t ldb, std::int8_t bo, value_or_pointer beta, std::int32_t *c, std::int64_t ldc, const std::int32_t *co, const std::vector &dependencies = {}); sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + std::int64_t m, std::int64_t n, std::int64_t k, value_or_pointer alpha, const std::uint8_t *a, std::int64_t lda, std::uint8_t ao, - const std::int8_t *b, std::int64_t ldb, std::int8_t bo, float beta, - std::int32_t *c, std::int64_t ldc, const std::int32_t *co, - const std::vector &dependencies = {}); + const std::int8_t *b, std::int64_t ldb, std::int8_t bo, + value_or_pointer beta, std::int32_t *c, std::int64_t ldc, + const std::int32_t *co, const std::vector &dependencies = {}); sycl::event gemm_bias(sycl::queue &queue, transpose transa, transpose transb, offset offsetc, - std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + std::int64_t m, std::int64_t n, std::int64_t k, value_or_pointer alpha, const std::uint8_t *a, std::int64_t lda, std::uint8_t ao, - const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, float beta, - std::int32_t *c, std::int64_t ldc, const std::int32_t *co, - const std::vector &dependencies = {}); + const std::uint8_t *b, std::int64_t ldb, std::uint8_t bo, + value_or_pointer beta, std::int32_t *c, std::int64_t ldc, + const std::int32_t *co, const std::vector &dependencies = {}); // batch, buffer @@ -1742,6 +1769,24 @@ void omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb); +void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &b, std::int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &b, std::int64_t ldb, std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t strideb); + +void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + void imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb); @@ -1806,25 +1851,27 @@ sycl::event syrk_batch(sycl::queue &queue, const uplo *upper_lower, const transp const std::vector &dependencies = {}); sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, float alpha, const float *a, std::int64_t lda, - std::int64_t stride_a, float beta, float *c, std::int64_t ldc, + std::int64_t k, value_or_pointer alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, value_or_pointer beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, double alpha, const double *a, std::int64_t lda, - std::int64_t stride_a, double beta, double *c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size, + std::int64_t k, value_or_pointer alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, value_or_pointer beta, + double *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::complex beta, - std::complex *c, std::int64_t ldc, std::int64_t stride_c, - std::int64_t batch_size, const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + value_or_pointer> beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, std::int64_t n, - std::int64_t k, std::complex alpha, const std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::complex beta, - std::complex *c, std::int64_t ldc, std::int64_t stride_c, - std::int64_t batch_size, const std::vector &dependencies = {}); + std::int64_t k, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + value_or_pointer> beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies = {}); sycl::event copy_batch(sycl::queue &queue, std::int64_t n, const float *x, std::int64_t incx, std::int64_t stridex, float *y, std::int64_t incy, std::int64_t stridey, @@ -1917,28 +1964,32 @@ sycl::event dgmm_batch(sycl::queue &queue, const side *left_right, const std::in const std::vector &dependencies = {}); sycl::event gemv_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - float alpha, const float *a, std::int64_t lda, std::int64_t stridea, - const float *x, std::int64_t incx, std::int64_t stridex, float beta, - float *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + value_or_pointer alpha, const float *a, std::int64_t lda, + std::int64_t stridea, const float *x, std::int64_t incx, + std::int64_t stridex, value_or_pointer beta, float *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemv_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - double alpha, const double *a, std::int64_t lda, std::int64_t stridea, - const double *x, std::int64_t incx, std::int64_t stridex, double beta, - double *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + value_or_pointer alpha, const double *a, std::int64_t lda, + std::int64_t stridea, const double *x, std::int64_t incx, + std::int64_t stridex, value_or_pointer beta, double *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemv_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, std::int64_t lda, - std::int64_t stridea, const std::complex *x, std::int64_t incx, - std::int64_t stridex, std::complex beta, std::complex *y, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, const std::complex *x, + std::int64_t incx, std::int64_t stridex, + value_or_pointer> beta, std::complex *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemv_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, std::int64_t lda, - std::int64_t stridea, const std::complex *x, std::int64_t incx, - std::int64_t stridex, std::complex beta, std::complex *y, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, const std::complex *x, + std::int64_t incx, std::int64_t stridex, + value_or_pointer> beta, std::complex *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); @@ -1998,25 +2049,26 @@ sycl::event axpy_batch(sycl::queue &queue, const std::int64_t *n, const std::com const std::int64_t *group_size, const std::vector &dependencies = {}); -sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, float alpha, const float *x, - std::int64_t incx, std::int64_t stridex, float *y, std::int64_t incy, - std::int64_t stridey, std::int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, + const float *x, std::int64_t incx, std::int64_t stridex, float *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); -sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, double alpha, const double *x, - std::int64_t incx, std::int64_t stridex, double *y, std::int64_t incy, - std::int64_t stridey, std::int64_t batch_size, +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, value_or_pointer alpha, + const double *x, std::int64_t incx, std::int64_t stridex, double *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); -sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, std::complex alpha, +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, value_or_pointer> alpha, const std::complex *x, std::int64_t incx, std::int64_t stridex, std::complex *y, std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, const std::vector &dependencies = {}); -sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, std::complex alpha, - const std::complex *x, std::int64_t incx, std::int64_t stridex, - std::complex *y, std::int64_t incy, std::int64_t stridey, - std::int64_t batch_size, const std::vector &dependencies = {}); +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, + value_or_pointer> alpha, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transpose *transb, const std::int64_t *m, const std::int64_t *n, const std::int64_t *k, @@ -2071,16 +2123,16 @@ sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transp sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transpose *transb, const std::int64_t *m, const std::int64_t *n, const std::int64_t *k, const float *alpha, const bfloat16 **a, const std::int64_t *lda, - const bfloat16 **b, const std::int64_t *ldb, const float *beta, bfloat16 **c, - const std::int64_t *ldc, std::int64_t group_count, + const bfloat16 **b, const std::int64_t *ldb, const float *beta, + bfloat16 **c, const std::int64_t *ldc, std::int64_t group_count, const std::int64_t *groupsize, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transpose *transb, const std::int64_t *m, const std::int64_t *n, const std::int64_t *k, const float *alpha, const bfloat16 **a, const std::int64_t *lda, - const bfloat16 **b, const std::int64_t *ldb, const float *beta, float **c, - const std::int64_t *ldc, std::int64_t group_count, + const bfloat16 **b, const std::int64_t *ldb, const float *beta, + float **c, const std::int64_t *ldc, std::int64_t group_count, const std::int64_t *groupsize, const std::vector &dependencies = {}); @@ -2101,100 +2153,105 @@ sycl::event gemm_batch(sycl::queue &queue, const transpose *transa, const transp const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const float *a, + std::int64_t n, std::int64_t k, value_or_pointer alpha, const float *a, std::int64_t lda, std::int64_t stride_a, const float *b, std::int64_t ldb, - std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_b, value_or_pointer beta, float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, double alpha, const double *a, + std::int64_t n, std::int64_t k, value_or_pointer alpha, const double *a, std::int64_t lda, std::int64_t stride_a, const double *b, std::int64_t ldb, - std::int64_t stride_b, double beta, double *c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size, + std::int64_t stride_b, value_or_pointer beta, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, std::complex alpha, + std::int64_t n, std::int64_t k, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, std::int64_t stride_a, const std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::complex beta, std::complex *c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size, + value_or_pointer> beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, std::complex alpha, + std::int64_t n, std::int64_t k, value_or_pointer> alpha, const std::complex *a, std::int64_t lda, std::int64_t stride_a, const std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::complex beta, std::complex *c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size, + value_or_pointer> beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, sycl::half alpha, const sycl::half *a, - std::int64_t lda, std::int64_t stride_a, const sycl::half *b, - std::int64_t ldb, std::int64_t stride_b, sycl::half beta, sycl::half *c, - std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer beta, sycl::half *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, - std::int64_t lda, std::int64_t stride_a, const sycl::half *b, - std::int64_t ldb, std::int64_t stride_b, float beta, float *c, - std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const bfloat16 *a, - std::int64_t lda, std::int64_t stride_a, const bfloat16 *b, std::int64_t ldb, - std::int64_t stride_b, float beta, bfloat16 *c, std::int64_t ldc, + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const bfloat16 *a, std::int64_t lda, std::int64_t stride_a, + const bfloat16 *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer beta, bfloat16 *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const bfloat16 *a, - std::int64_t lda, std::int64_t stride_a, const bfloat16 *b, std::int64_t ldb, - std::int64_t stride_b, float beta, float *c, std::int64_t ldc, - std::int64_t stride_c, std::int64_t batch_size, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const bfloat16 *a, std::int64_t lda, std::int64_t stride_a, + const bfloat16 *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, - std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, - std::int64_t ldb, std::int64_t stride_b, float beta, std::int32_t *c, - std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, - std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, - std::int64_t ldb, std::int64_t stride_b, float beta, float *c, - std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, - const std::vector &dependencies = {}); + std::int64_t n, std::int64_t k, value_or_pointer alpha, + const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, float alpha, const float *a, - std::int64_t lda, std::int64_t stride_a, float *b, std::int64_t ldb, - std::int64_t stride_b, std::int64_t batch_size, + diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + diag unit_diag, std::int64_t m, std::int64_t n, value_or_pointer alpha, const double *a, std::int64_t lda, std::int64_t stride_a, double *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, std::int64_t stride_a, - std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, const std::vector &dependencies = {}); + diag unit_diag, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, std::int64_t m, std::int64_t n, std::complex alpha, - const std::complex *a, std::int64_t lda, std::int64_t stride_a, - std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, const std::vector &dependencies = {}); + diag unit_diag, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies = {}); sycl::event trsm_batch(sycl::queue &queue, const side *left_right, const uplo *upper_lower, const transpose *trans, const diag *unit_diag, const std::int64_t *m, @@ -2227,183 +2284,211 @@ sycl::event trsm_batch(sycl::queue &queue, const side *left_right, const uplo *u const std::vector &dependencies = {}); sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, - float *b, std::int64_t ldb, std::int64_t stride_b, + value_or_pointer alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - double alpha, const double *a, std::int64_t lda, std::int64_t stride_a, - double *b, std::int64_t ldb, std::int64_t stride_b, - std::int64_t batch_size, + value_or_pointer alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, + value_or_pointer> alpha, const std::complex *a, std::int64_t lda, std::int64_t stride_a, std::complex *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, + value_or_pointer> alpha, const std::complex *a, std::int64_t lda, std::int64_t stride_a, std::complex *b, std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - float alpha, float *ab, std::int64_t lda, std::int64_t ldb, - std::int64_t stride, std::int64_t batch_size, + value_or_pointer alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - double alpha, double *ab, std::int64_t lda, std::int64_t ldb, - std::int64_t stride, std::int64_t batch_size, + value_or_pointer alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, std::complex *ab, std::int64_t lda, - std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + value_or_pointer> alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, std::complex *ab, std::int64_t lda, - std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + value_or_pointer> alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, float alpha, const float *a, std::int64_t lda, - std::int64_t stride_a, float beta, const float *b, std::int64_t ldb, - std::int64_t stride_b, float *c, std::int64_t ldc, std::int64_t stride_c, - std::int64_t batch_size, + std::int64_t n, value_or_pointer alpha, const float *a, + std::int64_t lda, std::int64_t stride_a, value_or_pointer beta, + const float *b, std::int64_t ldb, std::int64_t stride_b, float *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, double alpha, const double *a, std::int64_t lda, - std::int64_t stride_a, double beta, const double *b, std::int64_t ldb, - std::int64_t stride_b, double *c, std::int64_t ldc, std::int64_t stride_c, - std::int64_t batch_size, + std::int64_t n, value_or_pointer alpha, const double *a, + std::int64_t lda, std::int64_t stride_a, value_or_pointer beta, + const double *b, std::int64_t ldb, std::int64_t stride_b, double *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::complex beta, - const std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::complex *c, std::int64_t ldc, std::int64_t stride_c, - std::int64_t batch_size, + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + value_or_pointer> beta, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::complex beta, - const std::complex *b, std::int64_t ldb, std::int64_t stride_b, - std::complex *c, std::int64_t ldc, std::int64_t stride_c, - std::int64_t batch_size, + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + value_or_pointer> beta, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies = {}); sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - float alpha, const float *a, std::int64_t lda, float *b, std::int64_t ldb, - const std::vector &dependencies = {}); + value_or_pointer alpha, const float *a, std::int64_t lda, float *b, + std::int64_t ldb, const std::vector &dependencies = {}); sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - double alpha, const double *a, std::int64_t lda, double *b, std::int64_t ldb, - const std::vector &dependencies = {}); + value_or_pointer alpha, const double *a, std::int64_t lda, double *b, + std::int64_t ldb, const std::vector &dependencies = {}); sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, std::int64_t lda, - std::complex *b, std::int64_t ldb, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, const std::complex *a, std::int64_t lda, - std::complex *b, std::int64_t ldb, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, const std::vector &dependencies = {}); +sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + value_or_pointer alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + value_or_pointer alpha, const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + value_or_pointer> alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies = {}); + sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - float alpha, float *ab, std::int64_t lda, std::int64_t ldb, + value_or_pointer alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - double alpha, double *ab, std::int64_t lda, std::int64_t ldb, + value_or_pointer alpha, double *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies = {}); sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, std::complex *ab, std::int64_t lda, - std::int64_t ldb, const std::vector &dependencies = {}); + value_or_pointer> alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies = {}); sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, - std::complex alpha, std::complex *ab, std::int64_t lda, - std::int64_t ldb, const std::vector &dependencies = {}); + value_or_pointer> alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies = {}); sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, float alpha, const float *a, std::int64_t lda, float beta, - const float *b, std::int64_t ldb, float *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, value_or_pointer alpha, const float *a, std::int64_t lda, + value_or_pointer beta, const float *b, std::int64_t ldb, float *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, double alpha, const double *a, std::int64_t lda, double beta, - const double *b, std::int64_t ldb, double *c, std::int64_t ldc, - const std::vector &dependencies = {}); + std::int64_t n, value_or_pointer alpha, const double *a, std::int64_t lda, + value_or_pointer beta, const double *b, std::int64_t ldb, double *c, + std::int64_t ldc, const std::vector &dependencies = {}); sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, std::complex beta, const std::complex *b, + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, + value_or_pointer> beta, const std::complex *b, std::int64_t ldb, std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, - std::int64_t n, std::complex alpha, const std::complex *a, - std::int64_t lda, std::complex beta, const std::complex *b, + std::int64_t n, value_or_pointer> alpha, + const std::complex *a, std::int64_t lda, + value_or_pointer> beta, const std::complex *b, std::int64_t ldb, std::complex *c, std::int64_t ldc, const std::vector &dependencies = {}); -sycl::event omatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const float* alpha, const float** a, - const std::int64_t* lda, float** b, const std::int64_t* ldb, - std::int64_t group_count, const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event omatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const double* alpha, const double** a, - const std::int64_t* lda, double** b, const std::int64_t* ldb, - std::int64_t group_count, const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event omatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const std::complex* alpha, - const std::complex** a, const std::int64_t* lda, - std::complex** b, const std::int64_t* ldb, - std::int64_t group_count, const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event omatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const std::complex* alpha, - const std::complex** a, const std::int64_t* lda, - std::complex** b, const std::int64_t* ldb, - std::int64_t group_count, const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event imatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const float* alpha, float** ab, - const std::int64_t* lda, const std::int64_t* ldb, - std::int64_t group_count, const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event imatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const double* alpha, double** ab, - const std::int64_t* lda, const std::int64_t* ldb, - std::int64_t group_count, const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event imatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const std::complex* alpha, - std::complex** ab, const std::int64_t* lda, - const std::int64_t* ldb, std::int64_t group_count, - const std::int64_t* groupsize, - const std::vector& dependencies = {}); - -sycl::event imatcopy_batch(sycl::queue& queue, const transpose* trans, const std::int64_t* m, - const std::int64_t* n, const std::complex* alpha, - std::complex** ab, const std::int64_t* lda, - const std::int64_t* ldb, std::int64_t group_count, - const std::int64_t* groupsize, - const std::vector& dependencies = {}); +sycl::event omatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const float *alpha, const float **a, + const std::int64_t *lda, float **b, const std::int64_t *ldb, + std::int64_t group_count, const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const double *alpha, const double **a, + const std::int64_t *lda, double **b, const std::int64_t *ldb, + std::int64_t group_count, const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const std::complex *alpha, + const std::complex **a, const std::int64_t *lda, + std::complex **b, const std::int64_t *ldb, + std::int64_t group_count, const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event omatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const std::complex *alpha, + const std::complex **a, const std::int64_t *lda, + std::complex **b, const std::int64_t *ldb, + std::int64_t group_count, const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const float *alpha, float **ab, + const std::int64_t *lda, const std::int64_t *ldb, + std::int64_t group_count, const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const double *alpha, double **ab, + const std::int64_t *lda, const std::int64_t *ldb, + std::int64_t group_count, const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const std::complex *alpha, + std::complex **ab, const std::int64_t *lda, + const std::int64_t *ldb, std::int64_t group_count, + const std::int64_t *groupsize, + const std::vector &dependencies = {}); + +sycl::event imatcopy_batch(sycl::queue &queue, const transpose *trans, const std::int64_t *m, + const std::int64_t *n, const std::complex *alpha, + std::complex **ab, const std::int64_t *lda, + const std::int64_t *ldb, std::int64_t group_count, + const std::int64_t *groupsize, + const std::vector &dependencies = {}); diff --git a/src/blas/backends/mkl_common/mkl_extensions.cxx b/src/blas/backends/mkl_common/mkl_extensions.cxx index f20268554..4672af5c7 100644 --- a/src/blas/backends/mkl_common/mkl_extensions.cxx +++ b/src/blas/backends/mkl_common/mkl_extensions.cxx @@ -105,6 +105,31 @@ void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co blas_major::omatcopy(queue, trans, m, n, alpha, a, lda, b, ldb); } +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, int64_t ldb, + std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { blas_major::imatcopy(queue, trans, m, n, alpha, ab, lda, ldb); @@ -249,6 +274,32 @@ sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, return blas_major::omatcopy(queue, trans, m, n, alpha, a, lda, b, ldb, dependencies); } +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, std::int64_t stridea, float *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, std::int64_t stridea, double *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); +} + sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, const std::vector &dependencies) { diff --git a/src/blas/backends/mklcpu/CMakeLists.txt b/src/blas/backends/mklcpu/CMakeLists.txt index 691e61b3d..322741d26 100644 --- a/src/blas/backends/mklcpu/CMakeLists.txt +++ b/src/blas/backends/mklcpu/CMakeLists.txt @@ -20,13 +20,12 @@ set(LIB_NAME onemkl_blas_mklcpu) set(LIB_OBJ ${LIB_NAME}_obj) -set(USE_DPCPP_API ON) -find_package(MKL REQUIRED) set(SOURCES mklcpu_level1.cpp mklcpu_level2.cpp mklcpu_level3.cpp mklcpu_batch.cpp mklcpu_extensions.cpp $<$: mklcpu_wrappers.cpp>) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) endif() @@ -36,12 +35,16 @@ target_include_directories(${LIB_OBJ} ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_C} ${MKL_LINK_SYCL}) +if(TARGET MKL::MKL_SYCL::BLAS) + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_SYCL::BLAS) +else() + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_DPCPP) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/blas/backends/mklcpu/mklcpu_batch.cpp b/src/blas/backends/mklcpu/mklcpu_batch.cpp index 9dd231629..5ecf4cc69 100644 --- a/src/blas/backends/mklcpu/mklcpu_batch.cpp +++ b/src/blas/backends/mklcpu/mklcpu_batch.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklcpu/onemkl_blas_mklcpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/mklcpu/mklcpu_extensions.cpp b/src/blas/backends/mklcpu/mklcpu_extensions.cpp index d9d122eac..215addd5e 100644 --- a/src/blas/backends/mklcpu/mklcpu_extensions.cpp +++ b/src/blas/backends/mklcpu/mklcpu_extensions.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklcpu/onemkl_blas_mklcpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/mklgpu/CMakeLists.txt b/src/blas/backends/mklgpu/CMakeLists.txt index 07d9b2e7c..c971d1afd 100644 --- a/src/blas/backends/mklgpu/CMakeLists.txt +++ b/src/blas/backends/mklgpu/CMakeLists.txt @@ -19,25 +19,29 @@ set(LIB_NAME onemkl_blas_mklgpu) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT mklgpu_level1.cpp mklgpu_level2.cpp mklgpu_level3.cpp mklgpu_batch.cpp mklgpu_extensions.cpp $<$: mklgpu_wrappers.cpp> ) +add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) +if(TARGET MKL::MKL_SYCL::BLAS) + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_SYCL::BLAS) +else() + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_DPCPP) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/blas/backends/mklgpu/mklgpu_batch.cpp b/src/blas/backends/mklgpu/mklgpu_batch.cpp index d859a3b78..bad2db82c 100644 --- a/src/blas/backends/mklgpu/mklgpu_batch.cpp +++ b/src/blas/backends/mklgpu/mklgpu_batch.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklgpu/onemkl_blas_mklgpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/mklgpu/mklgpu_extensions.cpp b/src/blas/backends/mklgpu/mklgpu_extensions.cpp index 8ea7d5d6e..c4b1635c8 100644 --- a/src/blas/backends/mklgpu/mklgpu_extensions.cpp +++ b/src/blas/backends/mklgpu/mklgpu_extensions.cpp @@ -25,6 +25,7 @@ #include "oneapi/mkl/blas/detail/mklgpu/onemkl_blas_mklgpu.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "../mkl_common/mkl_blas_backend.hpp" namespace oneapi { diff --git a/src/blas/backends/netlib/CMakeLists.txt b/src/blas/backends/netlib/CMakeLists.txt index 4be16691a..fd5275fc0 100644 --- a/src/blas/backends/netlib/CMakeLists.txt +++ b/src/blas/backends/netlib/CMakeLists.txt @@ -29,7 +29,8 @@ set(SOURCES netlib_common.hpp ) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) - +add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) endif() @@ -40,6 +41,7 @@ target_include_directories(${LIB_OBJ} ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin ${NETLIB_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) diff --git a/src/blas/backends/netlib/netlib_batch.cxx b/src/blas/backends/netlib/netlib_batch.cxx index a029a60bc..7a2839dd4 100644 --- a/src/blas/backends/netlib/netlib_batch.cxx +++ b/src/blas/backends/netlib/netlib_batch.cxx @@ -279,6 +279,45 @@ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t #endif } +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, sycl::buffer &a, int64_t lda, + int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, + float beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, + int64_t batch_size) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, @@ -983,6 +1022,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, #endif } +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const sycl::half **a, int64_t *lda, + const sycl::half **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, float **c, int64_t *ldc, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, + int64_t *n, int64_t *k, float *alpha, const std::int8_t **a, int64_t *lda, + const std::int8_t **b, int64_t *ldb, float *beta, std::int32_t **c, + int64_t *ldc, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, float alpha, const float *a, int64_t lda, int64_t stride_a, const float *b, int64_t ldb, int64_t stride_b, @@ -1052,6 +1130,45 @@ sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, i #endif } +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, int64_t stride_a, + const sycl::half *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, float *c, + int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + +sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const std::int8_t *a, int64_t lda, int64_t stride_a, + const std::int8_t *b, int64_t ldb, int64_t stride_b, float beta, + std::int32_t *c, int64_t ldc, int64_t stride_c, int64_t batch_size, + const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "gemm_batch", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#endif +} + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, diff --git a/src/blas/backends/netlib/netlib_common.hpp b/src/blas/backends/netlib/netlib_common.hpp index ec2023e79..3a69c70f8 100644 --- a/src/blas/backends/netlib/netlib_common.hpp +++ b/src/blas/backends/netlib/netlib_common.hpp @@ -32,6 +32,8 @@ #include "oneapi/mkl/blas/detail/netlib/onemkl_blas_netlib.hpp" #include "oneapi/mkl/types.hpp" +#define GET_MULTI_PTR template get_multi_ptr().get_raw() + namespace oneapi { namespace mkl { namespace blas { diff --git a/src/blas/backends/netlib/netlib_extensions.cxx b/src/blas/backends/netlib/netlib_extensions.cxx index dd89796b1..8e94cb880 100644 --- a/src/blas/backends/netlib/netlib_extensions.cxx +++ b/src/blas/backends/netlib/netlib_extensions.cxx @@ -161,6 +161,51 @@ void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co #endif } +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &b, int64_t ldb, std::int64_t strideb) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + +void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, int64_t ldb, + std::int64_t strideb) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { #ifdef COLUMN_MAJOR @@ -397,6 +442,52 @@ sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, #endif } +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, + const float *a, int64_t lda, std::int64_t stridea, float *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, + const double *a, int64_t lda, std::int64_t stridea, double *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + std::int64_t stridea, std::complex *b, int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { +#ifdef COLUMN_MAJOR + throw unimplemented("blas", "omatcopy2", "for column_major layout"); +#endif +#ifdef ROW_MAJOR + throw unimplemented("blas", "omatcopy2", "for row_major layout"); +#endif +} + sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, const std::vector &dependencies) { diff --git a/src/blas/backends/netlib/netlib_level1.cpp b/src/blas/backends/netlib/netlib_level1.cpp index 7b3ea5827..59830db81 100644 --- a/src/blas/backends/netlib/netlib_level1.cpp +++ b/src/blas/backends/netlib/netlib_level1.cpp @@ -185,7 +185,7 @@ void cblas_zdrot(const int n, std::complex *zx, const int incx, std::com } } -void cblas_crotg(std::complex *ca, std::complex *cb, float *c, +void cblas_crotg(std::complex *ca, const std::complex *cb, float *c, std::complex *s) { if (std::abs(ca[0]) == 0) { c[0] = 0.0; @@ -203,7 +203,7 @@ void cblas_crotg(std::complex *ca, std::complex *cb, float *c, } } -void cblas_zrotg(std::complex *ca, std::complex *cb, double *c, +void cblas_zrotg(std::complex *ca, const std::complex *cb, double *c, std::complex *s) { if (std::abs(ca[0]) == 0) { c[0] = 0.0; diff --git a/src/blas/backends/netlib/netlib_level1.cxx b/src/blas/backends/netlib/netlib_level1.cxx index ec9cc0f45..9f953dc5b 100644 --- a/src/blas/backends/netlib/netlib_level1.cxx +++ b/src/blas/backends/netlib/netlib_level1.cxx @@ -26,7 +26,7 @@ void asum(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_sasum((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_sasum((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } @@ -38,55 +38,55 @@ void asum(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_dasum((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_dasum((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void asum(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void asum(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_scasum((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_scasum((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void asum(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void asum(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_dzasum((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_dzasum((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void axpy(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, - int64_t incx, sycl::buffer &y, int64_t incy) { +void axpy(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, + sycl::buffer &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_saxpy((const int)n, (const float)alpha, accessor_x.get_pointer(), - (const int)incx, accessor_y.get_pointer(), (const int)incy); + ::cblas_saxpy((const int)n, (const float)alpha, accessor_x.GET_MULTI_PTR, + (const int)incx, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void axpy(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx, sycl::buffer &y, int64_t incy) { +void axpy(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, int64_t incx, + sycl::buffer &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_daxpy((const int)n, (const double)alpha, accessor_x.get_pointer(), - (const int)incx, accessor_y.get_pointer(), (const int)incy); + ::cblas_daxpy((const int)n, (const double)alpha, accessor_x.GET_MULTI_PTR, + (const int)incx, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -98,8 +98,8 @@ void axpy(sycl::queue &queue, int64_t n, std::complex alpha, auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_caxpy((const int)n, (const void *)&alpha, accessor_x.get_pointer(), - (const int)incx, accessor_y.get_pointer(), (const int)incy); + ::cblas_caxpy((const int)n, (const void *)&alpha, accessor_x.GET_MULTI_PTR, + (const int)incx, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -111,14 +111,14 @@ void axpy(sycl::queue &queue, int64_t n, std::complex alpha, auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zaxpy((const int)n, (const void *)&alpha, accessor_x.get_pointer(), - (const int)incx, accessor_y.get_pointer(), (const int)incy); + ::cblas_zaxpy((const int)n, (const void *)&alpha, accessor_x.GET_MULTI_PTR, + (const int)incx, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void axpby(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, - int64_t incx, float beta, sycl::buffer &y, int64_t incy) { +void axpby(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, + float beta, sycl::buffer &y, int64_t incy) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "axpby", "for column_major layout"); #endif @@ -127,8 +127,8 @@ void axpby(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x #endif } -void axpby(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx, double beta, sycl::buffer &y, int64_t incy) { +void axpby(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, int64_t incx, + double beta, sycl::buffer &y, int64_t incy) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "axpby", "for column_major layout"); #endif @@ -165,8 +165,8 @@ void copy(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_scopy((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_scopy((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -177,32 +177,32 @@ void copy(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_dcopy((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_dcopy((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void copy(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy) { +void copy(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_ccopy((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_ccopy((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void copy(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy) { +void copy(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zcopy((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_zcopy((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -215,8 +215,8 @@ void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_sdot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_sdot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -229,8 +229,8 @@ void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_ddot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_ddot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -243,68 +243,68 @@ void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_dsdot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_dsdot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void dotc(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy, +void dotc(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, sycl::buffer, 1> &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - ::cblas_cdotc_sub((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, - accessor_result.get_pointer()); + ::cblas_cdotc_sub((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, + accessor_result.GET_MULTI_PTR); }); }); } -void dotc(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy, +void dotc(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, sycl::buffer, 1> &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zdotc_sub((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, - accessor_result.get_pointer()); + ::cblas_zdotc_sub((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, + accessor_result.GET_MULTI_PTR); }); }); } -void dotu(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy, +void dotu(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, sycl::buffer, 1> &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - ::cblas_cdotu_sub((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, - accessor_result.get_pointer()); + ::cblas_cdotu_sub((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, + accessor_result.GET_MULTI_PTR); }); }); } -void dotu(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy, +void dotu(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, sycl::buffer, 1> &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zdotu_sub((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, - accessor_result.get_pointer()); + ::cblas_zdotu_sub((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, + accessor_result.GET_MULTI_PTR); }); }); } @@ -315,7 +315,7 @@ void iamin(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_isamin((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_isamin((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } @@ -326,29 +326,29 @@ void iamin(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t in auto accessor_x = x.template get_access(cgh); auto accessor_result = result.template get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_idamin((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_idamin((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } -void iamin(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void iamin(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_icamin((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_icamin((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } -void iamin(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void iamin(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_izamin((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_izamin((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } @@ -359,7 +359,7 @@ void iamax(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_isamax((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_isamax((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } @@ -370,29 +370,29 @@ void iamax(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t in auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_idamax((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_idamax((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } -void iamax(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void iamax(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_icamax((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_icamax((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } -void iamax(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void iamax(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { - accessor_result[0] = ::cblas_izamax((int)n, accessor_x.get_pointer(), (int)incx); + accessor_result[0] = ::cblas_izamax((int)n, accessor_x.GET_MULTI_PTR, (int)incx); }); }); } @@ -404,7 +404,7 @@ void nrm2(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_result = result.template get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_snrm2((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_snrm2((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } @@ -416,31 +416,31 @@ void nrm2(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_dnrm2((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_dnrm2((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void nrm2(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void nrm2(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_scnrm2((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_scnrm2((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void nrm2(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer &result) { +void nrm2(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_dznrm2((const int)n, accessor_x.get_pointer(), (const int)std::abs(incx)); + ::cblas_dznrm2((const int)n, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } @@ -451,8 +451,8 @@ void rot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_srot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, (const float)c, (const float)s); + ::cblas_srot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, (const float)c, (const float)s); }); }); } @@ -463,35 +463,33 @@ void rot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_drot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, (const float)c, (const float)s); + ::cblas_drot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, (const float)c, (const float)s); }); }); } -void rot(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy, float c, - float s) { +void rot(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, float c, float s) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_csrot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, (const float)c, + ::cblas_csrot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, (const float)c, (const float)s); }); }); } -void rot(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy, double c, - double s) { +void rot(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, double c, double s) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zdrot((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, (const double)c, + ::cblas_zdrot((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, (const double)c, (const double)s); }); }); @@ -505,8 +503,8 @@ void rotg(sycl::queue &queue, sycl::buffer &a, sycl::buffer auto accessor_c = c.get_access(cgh); auto accessor_s = s.get_access(cgh); host_task(cgh, [=]() { - ::cblas_srotg(accessor_a.get_pointer(), accessor_b.get_pointer(), - accessor_c.get_pointer(), accessor_s.get_pointer()); + ::cblas_srotg(accessor_a.GET_MULTI_PTR, accessor_b.GET_MULTI_PTR, + accessor_c.GET_MULTI_PTR, accessor_s.GET_MULTI_PTR); }); }); } @@ -519,8 +517,8 @@ void rotg(sycl::queue &queue, sycl::buffer &a, sycl::buffer(cgh); auto accessor_s = s.get_access(cgh); host_task(cgh, [=]() { - ::cblas_drotg(accessor_a.get_pointer(), accessor_b.get_pointer(), - accessor_c.get_pointer(), accessor_s.get_pointer()); + ::cblas_drotg(accessor_a.GET_MULTI_PTR, accessor_b.GET_MULTI_PTR, + accessor_c.GET_MULTI_PTR, accessor_s.GET_MULTI_PTR); }); }); } @@ -534,8 +532,8 @@ void rotg(sycl::queue &queue, sycl::buffer, 1> &a, auto accessor_c = c.get_access(cgh); auto accessor_s = s.get_access(cgh); host_task(cgh, [=]() { - ::cblas_crotg(accessor_a.get_pointer(), accessor_b.get_pointer(), - accessor_c.get_pointer(), accessor_s.get_pointer()); + ::cblas_crotg(accessor_a.GET_MULTI_PTR, accessor_b.GET_MULTI_PTR, + accessor_c.GET_MULTI_PTR, accessor_s.GET_MULTI_PTR); }); }); } @@ -549,8 +547,8 @@ void rotg(sycl::queue &queue, sycl::buffer, 1> &a, auto accessor_c = c.get_access(cgh); auto accessor_s = s.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zrotg(accessor_a.get_pointer(), accessor_b.get_pointer(), - accessor_c.get_pointer(), accessor_s.get_pointer()); + ::cblas_zrotg(accessor_a.GET_MULTI_PTR, accessor_b.GET_MULTI_PTR, + accessor_c.GET_MULTI_PTR, accessor_s.GET_MULTI_PTR); }); }); } @@ -562,8 +560,8 @@ void rotm(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_y = y.get_access(cgh); auto accessor_param = param.get_access(cgh); host_task(cgh, [=]() { - ::cblas_srotm((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_param.get_pointer()); + ::cblas_srotm((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_param.GET_MULTI_PTR); }); }); } @@ -575,8 +573,8 @@ void rotm(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_y = y.get_access(cgh); auto accessor_param = param.get_access(cgh); host_task(cgh, [=]() { - ::cblas_drotm((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_param.get_pointer()); + ::cblas_drotm((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_param.GET_MULTI_PTR); }); }); } @@ -589,8 +587,8 @@ void rotmg(sycl::queue &queue, sycl::buffer &d1, sycl::buffer(cgh); auto accessor_param = param.get_access(cgh); host_task(cgh, [=]() { - ::cblas_srotmg(accessor_d1.get_pointer(), accessor_d2.get_pointer(), - accessor_x1.get_pointer(), (float)y1, accessor_param.get_pointer()); + ::cblas_srotmg(accessor_d1.GET_MULTI_PTR, accessor_d2.GET_MULTI_PTR, + accessor_x1.GET_MULTI_PTR, (float)y1, accessor_param.GET_MULTI_PTR); }); }); } @@ -603,29 +601,27 @@ void rotmg(sycl::queue &queue, sycl::buffer &d1, sycl::buffer(cgh); auto accessor_param = param.get_access(cgh); host_task(cgh, [=]() { - ::cblas_drotmg(accessor_d1.get_pointer(), accessor_d2.get_pointer(), - accessor_x1.get_pointer(), (double)y1, accessor_param.get_pointer()); + ::cblas_drotmg(accessor_d1.GET_MULTI_PTR, accessor_d2.GET_MULTI_PTR, + accessor_x1.GET_MULTI_PTR, (double)y1, accessor_param.GET_MULTI_PTR); }); }); } -void scal(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, - int64_t incx) { +void scal(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { - ::cblas_sscal((const int)n, (const float)alpha, accessor_x.get_pointer(), + ::cblas_sscal((const int)n, (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void scal(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx) { +void scal(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { - ::cblas_dscal((const int)n, (const double)alpha, accessor_x.get_pointer(), + ::cblas_dscal((const int)n, (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); @@ -636,18 +632,18 @@ void scal(sycl::queue &queue, int64_t n, std::complex alpha, queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { - ::cblas_cscal((const int)n, (const void *)&alpha, accessor_x.get_pointer(), + ::cblas_cscal((const int)n, (const void *)&alpha, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void scal(sycl::queue &queue, int64_t n, float alpha, - sycl::buffer, 1> &x, int64_t incx) { +void scal(sycl::queue &queue, int64_t n, float alpha, sycl::buffer, 1> &x, + int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { - ::cblas_csscal((const int)n, (const float)alpha, accessor_x.get_pointer(), + ::cblas_csscal((const int)n, (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); @@ -658,34 +654,33 @@ void scal(sycl::queue &queue, int64_t n, std::complex alpha, queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zscal((const int)n, (const void *)&alpha, accessor_x.get_pointer(), + ::cblas_zscal((const int)n, (const void *)&alpha, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void scal(sycl::queue &queue, int64_t n, double alpha, - sycl::buffer, 1> &x, int64_t incx) { +void scal(sycl::queue &queue, int64_t n, double alpha, sycl::buffer, 1> &x, + int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zdscal((const int)n, (const double)alpha, accessor_x.get_pointer(), + ::cblas_zdscal((const int)n, (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)std::abs(incx)); }); }); } -void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, - int64_t incx, sycl::buffer &y, int64_t incy, - sycl::buffer &result) { +void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, int64_t incx, + sycl::buffer &y, int64_t incy, sycl::buffer &result) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_result = result.get_access(cgh); host_task(cgh, [=]() { accessor_result[0] = - ::cblas_sdsdot((const int)n, (const float)sb, accessor_x.get_pointer(), - (const int)incx, accessor_y.get_pointer(), (const int)incy); + ::cblas_sdsdot((const int)n, (const float)sb, accessor_x.GET_MULTI_PTR, + (const int)incx, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -696,8 +691,8 @@ void swap(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_sswap((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_sswap((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -708,32 +703,32 @@ void swap(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t inc auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_dswap((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_dswap((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void swap(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy) { +void swap(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_cswap((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_cswap((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void swap(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, sycl::buffer, 1> &y, int64_t incy) { +void swap(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { - ::cblas_zswap((const int)n, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy); + ::cblas_zswap((const int)n, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -741,7 +736,7 @@ void swap(sycl::queue &queue, int64_t n, sycl::buffer, 1> & // USM APIs sycl::event asum(sycl::queue &queue, int64_t n, const float *x, int64_t incx, float *result, - const std::vector &dependencies) { + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -753,8 +748,8 @@ sycl::event asum(sycl::queue &queue, int64_t n, const float *x, int64_t incx, fl return done; } -sycl::event asum(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - double *result, const std::vector &dependencies) { +sycl::event asum(sycl::queue &queue, int64_t n, const double *x, int64_t incx, double *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -767,7 +762,7 @@ sycl::event asum(sycl::queue &queue, int64_t n, const double *x, int64_t incx, } sycl::event asum(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - float *result, const std::vector &dependencies) { + float *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -780,7 +775,7 @@ sycl::event asum(sycl::queue &queue, int64_t n, const std::complex *x, in } sycl::event asum(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - double *result, const std::vector &dependencies) { + double *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -792,8 +787,8 @@ sycl::event asum(sycl::queue &queue, int64_t n, const std::complex *x, i return done; } -sycl::event axpy(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, - float *y, int64_t incy, const std::vector &dependencies) { +sycl::event axpy(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, float *y, + int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -807,7 +802,7 @@ sycl::event axpy(sycl::queue &queue, int64_t n, float alpha, const float *x, int } sycl::event axpy(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, - double *y, int64_t incy, const std::vector &dependencies) { + double *y, int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -822,8 +817,8 @@ sycl::event axpy(sycl::queue &queue, int64_t n, double alpha, const double *x, i } sycl::event axpy(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, std::complex *y, - int64_t incy, const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -838,8 +833,8 @@ sycl::event axpy(sycl::queue &queue, int64_t n, std::complex alpha, } sycl::event axpy(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, std::complex *y, - int64_t incy, const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -854,8 +849,8 @@ sycl::event axpy(sycl::queue &queue, int64_t n, std::complex alpha, } sycl::event axpby(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, - float beta, float *y, int64_t incy, - const std::vector &dependencies) { + float beta, float *y, int64_t incy, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "axpby", "for column_major layout"); #endif @@ -864,9 +859,9 @@ sycl::event axpby(sycl::queue &queue, int64_t n, float alpha, const float *x, in #endif } -sycl::event axpby(sycl::queue &queue, int64_t n, double alpha, const double *x, - int64_t incx, double beta, double *y, int64_t incy, - const std::vector &dependencies) { +sycl::event axpby(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, + double beta, double *y, int64_t incy, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "axpby", "for column_major layout"); #endif @@ -876,9 +871,9 @@ sycl::event axpby(sycl::queue &queue, int64_t n, double alpha, const double *x, } sycl::event axpby(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "axpby", "for column_major layout"); #endif @@ -888,9 +883,9 @@ sycl::event axpby(sycl::queue &queue, int64_t n, std::complex alpha, } sycl::event axpby(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "axpby", "for column_major layout"); #endif @@ -900,7 +895,7 @@ sycl::event axpby(sycl::queue &queue, int64_t n, std::complex alpha, } sycl::event copy(sycl::queue &queue, int64_t n, const float *x, int64_t incx, float *y, - int64_t incy, const std::vector &dependencies) { + int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -913,7 +908,7 @@ sycl::event copy(sycl::queue &queue, int64_t n, const float *x, int64_t incx, fl } sycl::event copy(sycl::queue &queue, int64_t n, const double *x, int64_t incx, double *y, - int64_t incy, const std::vector &dependencies) { + int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -926,8 +921,8 @@ sycl::event copy(sycl::queue &queue, int64_t n, const double *x, int64_t incx, d } sycl::event copy(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -940,8 +935,8 @@ sycl::event copy(sycl::queue &queue, int64_t n, const std::complex *x, in } sycl::event copy(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -954,7 +949,7 @@ sycl::event copy(sycl::queue &queue, int64_t n, const std::complex *x, i } sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, - int64_t incy, float *result, const std::vector &dependencies) { + int64_t incy, float *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -967,9 +962,8 @@ sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, con return done; } -sycl::event dot(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - const double *y, int64_t incy, double *result, - const std::vector &dependencies) { +sycl::event dot(sycl::queue &queue, int64_t n, const double *x, int64_t incx, const double *y, + int64_t incy, double *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -983,8 +977,7 @@ sycl::event dot(sycl::queue &queue, int64_t n, const double *x, int64_t incx, } sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, - int64_t incy, double *result, - const std::vector &dependencies) { + int64_t incy, double *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -998,8 +991,8 @@ sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, con } sycl::event dotc(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - const std::complex *y, int64_t incy, std::complex *result, - const std::vector &dependencies) { + const std::complex *y, int64_t incy, std::complex *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1013,8 +1006,8 @@ sycl::event dotc(sycl::queue &queue, int64_t n, const std::complex *x, in } sycl::event dotc(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - const std::complex *y, int64_t incy, std::complex *result, - const std::vector &dependencies) { + const std::complex *y, int64_t incy, std::complex *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1028,8 +1021,8 @@ sycl::event dotc(sycl::queue &queue, int64_t n, const std::complex *x, i } sycl::event dotu(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - const std::complex *y, int64_t incy, std::complex *result, - const std::vector &dependencies) { + const std::complex *y, int64_t incy, std::complex *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1043,8 +1036,8 @@ sycl::event dotu(sycl::queue &queue, int64_t n, const std::complex *x, in } sycl::event dotu(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - const std::complex *y, int64_t incy, std::complex *result, - const std::vector &dependencies) { + const std::complex *y, int64_t incy, std::complex *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1057,8 +1050,8 @@ sycl::event dotu(sycl::queue &queue, int64_t n, const std::complex *x, i return done; } -sycl::event iamin(sycl::queue &queue, int64_t n, const float *x, int64_t incx, - int64_t *result, const std::vector &dependencies) { +sycl::event iamin(sycl::queue &queue, int64_t n, const float *x, int64_t incx, int64_t *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1070,8 +1063,8 @@ sycl::event iamin(sycl::queue &queue, int64_t n, const float *x, int64_t incx, return done; } -sycl::event iamin(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - int64_t *result, const std::vector &dependencies) { +sycl::event iamin(sycl::queue &queue, int64_t n, const double *x, int64_t incx, int64_t *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1084,7 +1077,7 @@ sycl::event iamin(sycl::queue &queue, int64_t n, const double *x, int64_t incx, } sycl::event iamin(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - int64_t *result, const std::vector &dependencies) { + int64_t *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1096,9 +1089,8 @@ sycl::event iamin(sycl::queue &queue, int64_t n, const std::complex *x, i return done; } -sycl::event iamin(sycl::queue &queue, int64_t n, const std::complex *x, - int64_t incx, int64_t *result, - const std::vector &dependencies) { +sycl::event iamin(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, + int64_t *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1110,8 +1102,8 @@ sycl::event iamin(sycl::queue &queue, int64_t n, const std::complex *x, return done; } -sycl::event iamax(sycl::queue &queue, int64_t n, const float *x, int64_t incx, - int64_t *result, const std::vector &dependencies) { +sycl::event iamax(sycl::queue &queue, int64_t n, const float *x, int64_t incx, int64_t *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1123,8 +1115,8 @@ sycl::event iamax(sycl::queue &queue, int64_t n, const float *x, int64_t incx, return done; } -sycl::event iamax(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - int64_t *result, const std::vector &dependencies) { +sycl::event iamax(sycl::queue &queue, int64_t n, const double *x, int64_t incx, int64_t *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1137,7 +1129,7 @@ sycl::event iamax(sycl::queue &queue, int64_t n, const double *x, int64_t incx, } sycl::event iamax(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - int64_t *result, const std::vector &dependencies) { + int64_t *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1149,9 +1141,8 @@ sycl::event iamax(sycl::queue &queue, int64_t n, const std::complex *x, i return done; } -sycl::event iamax(sycl::queue &queue, int64_t n, const std::complex *x, - int64_t incx, int64_t *result, - const std::vector &dependencies) { +sycl::event iamax(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, + int64_t *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1164,7 +1155,7 @@ sycl::event iamax(sycl::queue &queue, int64_t n, const std::complex *x, } sycl::event nrm2(sycl::queue &queue, int64_t n, const float *x, int64_t incx, float *result, - const std::vector &dependencies) { + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1176,8 +1167,8 @@ sycl::event nrm2(sycl::queue &queue, int64_t n, const float *x, int64_t incx, fl return done; } -sycl::event nrm2(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - double *result, const std::vector &dependencies) { +sycl::event nrm2(sycl::queue &queue, int64_t n, const double *x, int64_t incx, double *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1190,7 +1181,7 @@ sycl::event nrm2(sycl::queue &queue, int64_t n, const double *x, int64_t incx, } sycl::event nrm2(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - float *result, const std::vector &dependencies) { + float *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1203,7 +1194,7 @@ sycl::event nrm2(sycl::queue &queue, int64_t n, const std::complex *x, in } sycl::event nrm2(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - double *result, const std::vector &dependencies) { + double *result, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1215,9 +1206,8 @@ sycl::event nrm2(sycl::queue &queue, int64_t n, const std::complex *x, i return done; } -sycl::event rot(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y, - int64_t incy, float c, float s, - const std::vector &dependencies) { +sycl::event rot(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y, int64_t incy, + float c, float s, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1231,9 +1221,8 @@ sycl::event rot(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y, return done; } -sycl::event rot(sycl::queue &queue, int64_t n, double *x, int64_t incx, double *y, - int64_t incy, double c, double s, - const std::vector &dependencies) { +sycl::event rot(sycl::queue &queue, int64_t n, double *x, int64_t incx, double *y, int64_t incy, + double c, double s, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1248,8 +1237,8 @@ sycl::event rot(sycl::queue &queue, int64_t n, double *x, int64_t incx, double * } sycl::event rot(sycl::queue &queue, int64_t n, std::complex *x, int64_t incx, - std::complex *y, int64_t incy, float c, float s, - const std::vector &dependencies) { + std::complex *y, int64_t incy, float c, float s, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1264,8 +1253,8 @@ sycl::event rot(sycl::queue &queue, int64_t n, std::complex *x, int64_t i } sycl::event rot(sycl::queue &queue, int64_t n, std::complex *x, int64_t incx, - std::complex *y, int64_t incy, double c, double s, - const std::vector &dependencies) { + std::complex *y, int64_t incy, double c, double s, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1280,7 +1269,7 @@ sycl::event rot(sycl::queue &queue, int64_t n, std::complex *x, int64_t } sycl::event rotg(sycl::queue &queue, float *a, float *b, float *c, float *s, - const std::vector &dependencies) { + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1292,7 +1281,7 @@ sycl::event rotg(sycl::queue &queue, float *a, float *b, float *c, float *s, } sycl::event rotg(sycl::queue &queue, double *a, double *b, double *c, double *s, - const std::vector &dependencies) { + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1303,9 +1292,8 @@ sycl::event rotg(sycl::queue &queue, double *a, double *b, double *c, double *s, return done; } -sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex *b, - float *c, std::complex *s, - const std::vector &dependencies) { +sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex *b, float *c, + std::complex *s, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1316,9 +1304,8 @@ sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex return done; } -sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex *b, - double *c, std::complex *s, - const std::vector &dependencies) { +sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex *b, double *c, + std::complex *s, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1329,8 +1316,8 @@ sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex &dependencies) { +sycl::event rotm(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y, int64_t incy, + float *param, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1343,9 +1330,8 @@ sycl::event rotm(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y return done; } -sycl::event rotm(sycl::queue &queue, int64_t n, double *x, int64_t incx, double *y, - int64_t incy, double *param, - const std::vector &dependencies) { +sycl::event rotm(sycl::queue &queue, int64_t n, double *x, int64_t incx, double *y, int64_t incy, + double *param, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1358,8 +1344,8 @@ sycl::event rotm(sycl::queue &queue, int64_t n, double *x, int64_t incx, double return done; } -sycl::event rotmg(sycl::queue &queue, float *d1, float *d2, float *x1, float y1, - float *param, const std::vector &dependencies) { +sycl::event rotmg(sycl::queue &queue, float *d1, float *d2, float *x1, float y1, float *param, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1371,8 +1357,8 @@ sycl::event rotmg(sycl::queue &queue, float *d1, float *d2, float *x1, float y1, return done; } -sycl::event rotmg(sycl::queue &queue, double *d1, double *d2, double *x1, double y1, - double *param, const std::vector &dependencies) { +sycl::event rotmg(sycl::queue &queue, double *d1, double *d2, double *x1, double y1, double *param, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1385,7 +1371,7 @@ sycl::event rotmg(sycl::queue &queue, double *d1, double *d2, double *x1, double } sycl::event scal(sycl::queue &queue, int64_t n, float alpha, float *x, int64_t incx, - const std::vector &dependencies) { + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1399,7 +1385,7 @@ sycl::event scal(sycl::queue &queue, int64_t n, float alpha, float *x, int64_t i } sycl::event scal(sycl::queue &queue, int64_t n, double alpha, double *x, int64_t incx, - const std::vector &dependencies) { + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1412,9 +1398,8 @@ sycl::event scal(sycl::queue &queue, int64_t n, double alpha, double *x, int64_t return done; } -sycl::event scal(sycl::queue &queue, int64_t n, std::complex alpha, - std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event scal(sycl::queue &queue, int64_t n, std::complex alpha, std::complex *x, + int64_t incx, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1427,8 +1412,8 @@ sycl::event scal(sycl::queue &queue, int64_t n, std::complex alpha, return done; } -sycl::event scal(sycl::queue &queue, int64_t n, float alpha, std::complex *x, - int64_t incx, const std::vector &dependencies) { +sycl::event scal(sycl::queue &queue, int64_t n, float alpha, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1441,9 +1426,8 @@ sycl::event scal(sycl::queue &queue, int64_t n, float alpha, std::complex return done; } -sycl::event scal(sycl::queue &queue, int64_t n, std::complex alpha, - std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event scal(sycl::queue &queue, int64_t n, std::complex alpha, std::complex *x, + int64_t incx, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1456,8 +1440,8 @@ sycl::event scal(sycl::queue &queue, int64_t n, std::complex alpha, return done; } -sycl::event scal(sycl::queue &queue, int64_t n, double alpha, std::complex *x, - int64_t incx, const std::vector &dependencies) { +sycl::event scal(sycl::queue &queue, int64_t n, double alpha, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1471,8 +1455,8 @@ sycl::event scal(sycl::queue &queue, int64_t n, double alpha, std::complex &dependencies) { + const float *y, int64_t incy, float *result, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1486,8 +1470,8 @@ sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int6 return done; } -sycl::event swap(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y, - int64_t incy, const std::vector &dependencies) { +sycl::event swap(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1499,8 +1483,8 @@ sycl::event swap(sycl::queue &queue, int64_t n, float *x, int64_t incx, float *y return done; } -sycl::event swap(sycl::queue &queue, int64_t n, double *x, int64_t incx, double *y, - int64_t incy, const std::vector &dependencies) { +sycl::event swap(sycl::queue &queue, int64_t n, double *x, int64_t incx, double *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1513,8 +1497,8 @@ sycl::event swap(sycl::queue &queue, int64_t n, double *x, int64_t incx, double } sycl::event swap(sycl::queue &queue, int64_t n, std::complex *x, int64_t incx, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1527,8 +1511,8 @@ sycl::event swap(sycl::queue &queue, int64_t n, std::complex *x, int64_t } sycl::event swap(sycl::queue &queue, int64_t n, std::complex *x, int64_t incx, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { diff --git a/src/blas/backends/netlib/netlib_level2.cxx b/src/blas/backends/netlib/netlib_level2.cxx index f4ce7e17f..156ed133b 100644 --- a/src/blas/backends/netlib/netlib_level2.cxx +++ b/src/blas/backends/netlib/netlib_level2.cxx @@ -29,8 +29,8 @@ void gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, host_task(cgh, [=]() { ::cblas_sgbmv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, (const int)kl, (const int)ku, (const float)alpha, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), - (const int)incx, (const float)beta, accessor_y.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, + (const int)incx, (const float)beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); @@ -46,8 +46,8 @@ void gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, host_task(cgh, [=]() { ::cblas_dgbmv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, (const int)kl, (const int)ku, (const double)alpha, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), - (const int)incx, (const double)beta, accessor_y.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, + (const int)incx, (const double)beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); @@ -64,8 +64,8 @@ void gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, host_task(cgh, [=]() { ::cblas_cgbmv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, (const int)kl, (const int)ku, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), - (const int)incx, (const void *)&beta, accessor_y.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, + (const int)incx, (const void *)&beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); @@ -82,8 +82,8 @@ void gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, host_task(cgh, [=]() { ::cblas_zgbmv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, (const int)kl, (const int)ku, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), - (const int)incx, (const void *)&beta, accessor_y.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, + (const int)incx, (const void *)&beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); @@ -98,9 +98,9 @@ void gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_sgemv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, - (const float)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const float)beta, - accessor_y.get_pointer(), (const int)incy); + (const float)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const float)beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -114,9 +114,9 @@ void gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alph auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_dgemv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, - (const double)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const double)beta, - accessor_y.get_pointer(), (const int)incy); + (const double)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const double)beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -131,9 +131,9 @@ void gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::comple auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_cgemv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const void *)&beta, - accessor_y.get_pointer(), (const int)incy); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const void *)&beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -148,9 +148,9 @@ void gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::comple auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_zgemv(MAJOR, convert_to_cblas_trans(trans), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const void *)&beta, - accessor_y.get_pointer(), (const int)incy); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const void *)&beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -164,8 +164,8 @@ void ger(sycl::queue &queue, int64_t m, int64_t n, float alpha, sycl::buffer(cgh); host_task(cgh, [=]() { ::cblas_sger(MAJOR, (const int)m, (const int)n, (const float)alpha, - accessor_x.get_pointer(), (const int)incx, accessor_y.get_pointer(), - (const int)incy, accessor_a.get_pointer(), (const int)lda); + accessor_x.GET_MULTI_PTR, (const int)incx, accessor_y.GET_MULTI_PTR, + (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -179,8 +179,8 @@ void ger(sycl::queue &queue, int64_t m, int64_t n, double alpha, sycl::buffer(cgh); host_task(cgh, [=]() { ::cblas_dger(MAJOR, (const int)m, (const int)n, (const double)alpha, - accessor_x.get_pointer(), (const int)incx, accessor_y.get_pointer(), - (const int)incy, accessor_a.get_pointer(), (const int)lda); + accessor_x.GET_MULTI_PTR, (const int)incx, accessor_y.GET_MULTI_PTR, + (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -195,8 +195,8 @@ void gerc(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_cgerc(MAJOR, (const int)m, (const int)n, (const void *)&alpha, - accessor_x.get_pointer(), (const int)incx, accessor_y.get_pointer(), - (const int)incy, accessor_a.get_pointer(), (const int)lda); + accessor_x.GET_MULTI_PTR, (const int)incx, accessor_y.GET_MULTI_PTR, + (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -211,8 +211,8 @@ void gerc(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_zgerc(MAJOR, (const int)m, (const int)n, (const void *)&alpha, - accessor_x.get_pointer(), (const int)incx, accessor_y.get_pointer(), - (const int)incy, accessor_a.get_pointer(), (const int)lda); + accessor_x.GET_MULTI_PTR, (const int)incx, accessor_y.GET_MULTI_PTR, + (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -227,8 +227,8 @@ void geru(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_cgeru(MAJOR, (const int)m, (const int)n, (const void *)&alpha, - accessor_x.get_pointer(), (const int)incx, accessor_y.get_pointer(), - (const int)incy, accessor_a.get_pointer(), (const int)lda); + accessor_x.GET_MULTI_PTR, (const int)incx, accessor_y.GET_MULTI_PTR, + (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -243,8 +243,8 @@ void geru(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_zgeru(MAJOR, (const int)m, (const int)n, (const void *)&alpha, - accessor_x.get_pointer(), (const int)incx, accessor_y.get_pointer(), - (const int)incy, accessor_a.get_pointer(), (const int)lda); + accessor_x.GET_MULTI_PTR, (const int)incx, accessor_y.GET_MULTI_PTR, + (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -259,15 +259,15 @@ void hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, std::compl auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_chbmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, (const int)k, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const void *)&beta, - accessor_y.get_pointer(), (const int)incy); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const void *)&beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, +void hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, std::complex alpha, + sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &x, int64_t incx, std::complex beta, sycl::buffer, 1> &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { @@ -276,9 +276,9 @@ void hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_zhbmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, (const int)k, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const void *)&beta, - accessor_y.get_pointer(), (const int)incy); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const void *)&beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -293,9 +293,9 @@ void hemv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex a auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_chemv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const void *)&beta, - accessor_y.get_pointer(), (const int)incy); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const void *)&beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -310,9 +310,9 @@ void hemv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_zhemv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const void *)&beta, - accessor_y.get_pointer(), (const int)incy); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const void *)&beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -325,8 +325,8 @@ void her(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_cher(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_a.get_pointer(), (const int)lda); + (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -339,8 +339,8 @@ void her(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_zher(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_a.get_pointer(), (const int)lda); + (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } @@ -355,8 +355,8 @@ void her2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex a auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_cher2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_a.get_pointer(), + (const void *)&alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); @@ -372,8 +372,8 @@ void her2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_zher2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_a.get_pointer(), + (const void *)&alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); @@ -389,25 +389,25 @@ void hpmv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex a auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_chpmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_ap.get_pointer(), accessor_x.get_pointer(), - (const int)incx, (const void *)&beta, accessor_y.get_pointer(), + (const void *)&alpha, accessor_ap.GET_MULTI_PTR, accessor_x.GET_MULTI_PTR, + (const int)incx, (const void *)&beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } void hpmv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, - sycl::buffer, 1> &ap, - sycl::buffer, 1> &x, int64_t incx, std::complex beta, - sycl::buffer, 1> &y, int64_t incy) { + sycl::buffer, 1> &ap, sycl::buffer, 1> &x, + int64_t incx, std::complex beta, sycl::buffer, 1> &y, + int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_ap = ap.get_access(cgh); auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_zhpmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_ap.get_pointer(), accessor_x.get_pointer(), - (const int)incx, (const void *)&beta, accessor_y.get_pointer(), + (const void *)&alpha, accessor_ap.GET_MULTI_PTR, accessor_x.GET_MULTI_PTR, + (const int)incx, (const void *)&beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); @@ -421,8 +421,8 @@ void hpr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_chpr(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_ap.get_pointer()); + (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_ap.GET_MULTI_PTR); }); }); } @@ -435,8 +435,8 @@ void hpr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_zhpr(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_ap.get_pointer()); + (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_ap.GET_MULTI_PTR); }); }); } @@ -451,8 +451,8 @@ void hpr2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex a auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_chpr2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_ap.get_pointer()); + (const void *)&alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_ap.GET_MULTI_PTR); }); }); } @@ -467,8 +467,8 @@ void hpr2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_zhpr2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const void *)&alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_ap.get_pointer()); + (const void *)&alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_ap.GET_MULTI_PTR); }); }); } @@ -482,9 +482,9 @@ void sbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, float alph auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_ssbmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, (const int)k, - (const float)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const float)beta, - accessor_y.get_pointer(), (const int)incy); + (const float)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const float)beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } @@ -498,187 +498,184 @@ void sbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, double alp auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_dsbmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, (const int)k, - (const double)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const double)beta, - accessor_y.get_pointer(), (const int)incy); + (const double)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const double)beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void spmv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - sycl::buffer &ap, sycl::buffer &x, int64_t incx, float beta, - sycl::buffer &y, int64_t incy) { +void spmv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, sycl::buffer &ap, + sycl::buffer &x, int64_t incx, float beta, sycl::buffer &y, + int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_ap = ap.get_access(cgh); auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_sspmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_ap.get_pointer(), accessor_x.get_pointer(), - (const int)incx, (const float)beta, accessor_y.get_pointer(), + (const float)alpha, accessor_ap.GET_MULTI_PTR, accessor_x.GET_MULTI_PTR, + (const int)incx, (const float)beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } void spmv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - sycl::buffer &ap, sycl::buffer &x, int64_t incx, - double beta, sycl::buffer &y, int64_t incy) { + sycl::buffer &ap, sycl::buffer &x, int64_t incx, double beta, + sycl::buffer &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_ap = ap.get_access(cgh); auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_dspmv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_ap.get_pointer(), accessor_x.get_pointer(), - (const int)incx, (const double)beta, accessor_y.get_pointer(), + (const double)alpha, accessor_ap.GET_MULTI_PTR, accessor_x.GET_MULTI_PTR, + (const int)incx, (const double)beta, accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void spr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &ap) { +void spr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &ap) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_sspr(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_ap.get_pointer()); + (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_ap.GET_MULTI_PTR); }); }); } -void spr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &ap) { +void spr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &ap) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_dspr(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_ap.get_pointer()); + (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_ap.GET_MULTI_PTR); }); }); } -void spr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, - sycl::buffer &ap) { +void spr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &ap) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_sspr2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_ap.get_pointer()); + (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_ap.GET_MULTI_PTR); }); }); } -void spr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &y, - int64_t incy, sycl::buffer &ap) { +void spr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &ap) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_ap = ap.get_access(cgh); host_task(cgh, [=]() { ::cblas_dspr2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_ap.get_pointer()); + (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_ap.GET_MULTI_PTR); }); }); } -void symv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, - float beta, sycl::buffer &y, int64_t incy) { +void symv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, sycl::buffer &a, + int64_t lda, sycl::buffer &x, int64_t incx, float beta, + sycl::buffer &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_ssymv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const float)beta, - accessor_y.get_pointer(), (const int)incy); + (const float)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const float)beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void symv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, - double beta, sycl::buffer &y, int64_t incy) { +void symv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, sycl::buffer &a, + int64_t lda, sycl::buffer &x, int64_t incx, double beta, + sycl::buffer &y, int64_t incy) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); host_task(cgh, [=]() { ::cblas_dsymv(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_x.get_pointer(), (const int)incx, (const double)beta, - accessor_y.get_pointer(), (const int)incy); + (const double)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_x.GET_MULTI_PTR, (const int)incx, (const double)beta, + accessor_y.GET_MULTI_PTR, (const int)incy); }); }); } -void syr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &a, int64_t lda) { +void syr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &a, int64_t lda) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_ssyr(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_a.get_pointer(), (const int)lda); + (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } -void syr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &a, - int64_t lda) { +void syr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &a, int64_t lda) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_dsyr(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_a.get_pointer(), (const int)lda); + (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } -void syr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, - sycl::buffer &a, int64_t lda) { +void syr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a, + int64_t lda) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_ssyr2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const float)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_a.get_pointer(), + (const float)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); } -void syr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - sycl::buffer &x, int64_t incx, sycl::buffer &y, - int64_t incy, sycl::buffer &a, int64_t lda) { +void syr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, sycl::buffer &x, + int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a, + int64_t lda) { queue.submit([&](sycl::handler &cgh) { auto accessor_x = x.get_access(cgh); auto accessor_y = y.get_access(cgh); auto accessor_a = a.get_access(cgh); host_task(cgh, [=]() { ::cblas_dsyr2(MAJOR, convert_to_cblas_uplo(upper_lower), (const int)n, - (const double)alpha, accessor_x.get_pointer(), (const int)incx, - accessor_y.get_pointer(), (const int)incy, accessor_a.get_pointer(), + (const double)alpha, accessor_x.GET_MULTI_PTR, (const int)incx, + accessor_y.GET_MULTI_PTR, (const int)incy, accessor_a.GET_MULTI_PTR, (const int)lda); }); }); @@ -693,7 +690,7 @@ void tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_stbmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -708,7 +705,7 @@ void tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_dtbmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -723,7 +720,7 @@ void tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_ctbmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -738,7 +735,7 @@ void tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_ztbmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -753,7 +750,7 @@ void tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_stbsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -768,7 +765,7 @@ void tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_dtbsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -783,7 +780,7 @@ void tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_ctbsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -798,7 +795,7 @@ void tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, host_task(cgh, [=]() { ::cblas_ztbsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), convert_to_cblas_diag(unit_diag), (const int)n, (const int)k, - accessor_a.get_pointer(), (const int)lda, accessor_x.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); @@ -811,8 +808,8 @@ void tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_stpmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -824,8 +821,8 @@ void tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_dtpmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -838,22 +835,22 @@ void tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_ctpmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } void tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, - sycl::buffer, 1> &ap, - sycl::buffer, 1> &x, int64_t incx) { + sycl::buffer, 1> &ap, sycl::buffer, 1> &x, + int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_ap = ap.get_access(cgh); auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_ztpmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -865,8 +862,8 @@ void tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_stpsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -878,8 +875,8 @@ void tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_dtpsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -892,22 +889,22 @@ void tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_ctpsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } void tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, - sycl::buffer, 1> &ap, - sycl::buffer, 1> &x, int64_t incx) { + sycl::buffer, 1> &ap, sycl::buffer, 1> &x, + int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_ap = ap.get_access(cgh); auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_ztpsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.get_pointer(), - accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_ap.GET_MULTI_PTR, + accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -919,22 +916,21 @@ void trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag auto accessor_b = b.get_access(cgh); host_task(cgh, [=]() { ::cblas_strmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_b.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_b.GET_MULTI_PTR, (const int)incx); }); }); } void trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, int64_t n, - sycl::buffer &a, int64_t lda, sycl::buffer &b, - int64_t incx) { + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); host_task(cgh, [=]() { ::cblas_dtrmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_b.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_b.GET_MULTI_PTR, (const int)incx); }); }); } @@ -947,8 +943,8 @@ void trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag auto accessor_b = b.get_access(cgh); host_task(cgh, [=]() { ::cblas_ctrmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_b.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_b.GET_MULTI_PTR, (const int)incx); }); }); } @@ -961,8 +957,8 @@ void trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag auto accessor_b = b.get_access(cgh); host_task(cgh, [=]() { ::cblas_ztrmv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_b.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_b.GET_MULTI_PTR, (const int)incx); }); }); } @@ -974,22 +970,21 @@ void trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_strsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } void trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, - sycl::buffer &a, int64_t lda, sycl::buffer &x, - int64_t incx) { + sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_dtrsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -1002,8 +997,8 @@ void trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_ctrsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } @@ -1016,18 +1011,17 @@ void trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, auto accessor_x = x.get_access(cgh); host_task(cgh, [=]() { ::cblas_ztrsv(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.get_pointer(), - (const int)lda, accessor_x.get_pointer(), (const int)incx); + convert_to_cblas_diag(unit_diag), (const int)n, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_x.GET_MULTI_PTR, (const int)incx); }); }); } // USM APIs -sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, - int64_t ku, float alpha, const float *a, int64_t lda, const float *x, - int64_t incx, float beta, float *y, int64_t incy, - const std::vector &dependencies) { +sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, int64_t ku, + float alpha, const float *a, int64_t lda, const float *x, int64_t incx, float beta, + float *y, int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1042,10 +1036,10 @@ sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int6 return done; } -sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, - int64_t ku, double alpha, const double *a, int64_t lda, const double *x, - int64_t incx, double beta, double *y, int64_t incy, - const std::vector &dependencies) { +sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, int64_t ku, + double alpha, const double *a, int64_t lda, const double *x, int64_t incx, + double beta, double *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1060,11 +1054,11 @@ sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int6 return done; } -sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, - int64_t ku, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *x, int64_t incx, - std::complex beta, std::complex *y, int64_t incy, - const std::vector &dependencies) { +sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, int64_t ku, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1079,11 +1073,11 @@ sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int6 return done; } -sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, - int64_t ku, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *x, int64_t incx, - std::complex beta, std::complex *y, int64_t incy, - const std::vector &dependencies) { +sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, int64_t ku, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1099,8 +1093,8 @@ sycl::event gbmv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, int6 } sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, const float *x, int64_t incx, float beta, - float *y, int64_t incy, const std::vector &dependencies) { + const float *a, int64_t lda, const float *x, int64_t incx, float beta, float *y, + int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1116,8 +1110,8 @@ sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, floa } sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, const double *x, int64_t incx, double beta, - double *y, int64_t incy, const std::vector &dependencies) { + const double *a, int64_t lda, const double *x, int64_t incx, double beta, + double *y, int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1133,10 +1127,10 @@ sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, doub } sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1152,10 +1146,10 @@ sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, } sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1170,9 +1164,9 @@ sycl::event gemv(sycl::queue &queue, transpose trans, int64_t m, int64_t n, return done; } -sycl::event ger(sycl::queue &queue, int64_t m, int64_t n, float alpha, const float *x, - int64_t incx, const float *y, int64_t incy, float *a, int64_t lda, - const std::vector &dependencies) { +sycl::event ger(sycl::queue &queue, int64_t m, int64_t n, float alpha, const float *x, int64_t incx, + const float *y, int64_t incy, float *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1187,8 +1181,8 @@ sycl::event ger(sycl::queue &queue, int64_t m, int64_t n, float alpha, const flo } sycl::event ger(sycl::queue &queue, int64_t m, int64_t n, double alpha, const double *x, - int64_t incx, const double *y, int64_t incy, double *a, int64_t lda, - const std::vector &dependencies) { + int64_t incx, const double *y, int64_t incy, double *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1203,9 +1197,9 @@ sycl::event ger(sycl::queue &queue, int64_t m, int64_t n, double alpha, const do } sycl::event gerc(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, const std::complex *y, - int64_t incy, std::complex *a, int64_t lda, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1220,9 +1214,9 @@ sycl::event gerc(sycl::queue &queue, int64_t m, int64_t n, std::complex a } sycl::event gerc(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, const std::complex *y, - int64_t incy, std::complex *a, int64_t lda, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1237,9 +1231,9 @@ sycl::event gerc(sycl::queue &queue, int64_t m, int64_t n, std::complex } sycl::event geru(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, const std::complex *y, - int64_t incy, std::complex *a, int64_t lda, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1254,9 +1248,9 @@ sycl::event geru(sycl::queue &queue, int64_t m, int64_t n, std::complex a } sycl::event geru(sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, const std::complex *y, - int64_t incy, std::complex *a, int64_t lda, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1271,10 +1265,10 @@ sycl::event geru(sycl::queue &queue, int64_t m, int64_t n, std::complex } sycl::event hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1290,10 +1284,10 @@ sycl::event hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, } sycl::event hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1309,9 +1303,9 @@ sycl::event hbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, } sycl::event hemv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, const std::complex *x, - int64_t incx, std::complex beta, std::complex *y, int64_t incy, - const std::vector &dependencies) { + const std::complex *a, int64_t lda, const std::complex *x, + int64_t incx, std::complex beta, std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1326,11 +1320,10 @@ sycl::event hemv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, const std::complex *a, int64_t lda, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { +sycl::event hemv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, const std::complex *x, + int64_t incx, std::complex beta, std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1346,8 +1339,8 @@ sycl::event hemv(sycl::queue &queue, uplo upper_lower, int64_t n, } sycl::event her(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const std::complex *x, int64_t incx, std::complex *a, int64_t lda, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1362,8 +1355,8 @@ sycl::event her(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, } sycl::event her(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const std::complex *x, int64_t incx, std::complex *a, - int64_t lda, const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1378,9 +1371,9 @@ sycl::event her(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, } sycl::event her2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, const std::complex *y, - int64_t incy, std::complex *a, int64_t lda, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1395,10 +1388,10 @@ sycl::event her2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, const std::complex *x, int64_t incx, - const std::complex *y, int64_t incy, std::complex *a, - int64_t lda, const std::vector &dependencies) { +sycl::event her2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1414,9 +1407,9 @@ sycl::event her2(sycl::queue &queue, uplo upper_lower, int64_t n, } sycl::event hpmv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, - const std::complex *ap, const std::complex *x, int64_t incx, - std::complex beta, std::complex *y, int64_t incy, - const std::vector &dependencies) { + const std::complex *ap, const std::complex *x, int64_t incx, + std::complex beta, std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1431,11 +1424,10 @@ sycl::event hpmv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, const std::complex *ap, - const std::complex *x, int64_t incx, std::complex beta, - std::complex *y, int64_t incy, - const std::vector &dependencies) { +sycl::event hpmv(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, + const std::complex *ap, const std::complex *x, int64_t incx, + std::complex beta, std::complex *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1451,8 +1443,8 @@ sycl::event hpmv(sycl::queue &queue, uplo upper_lower, int64_t n, } sycl::event hpr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const std::complex *x, int64_t incx, std::complex *ap, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex *ap, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1467,8 +1459,8 @@ sycl::event hpr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, } sycl::event hpr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const std::complex *x, int64_t incx, std::complex *ap, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, std::complex *ap, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1483,9 +1475,9 @@ sycl::event hpr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, } sycl::event hpr2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, const std::complex *y, - int64_t incy, std::complex *ap, - const std::vector &dependencies) { + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *ap, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1499,10 +1491,10 @@ sycl::event hpr2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, const std::complex *x, int64_t incx, - const std::complex *y, int64_t incy, std::complex *ap, - const std::vector &dependencies) { +sycl::event hpr2(sycl::queue &queue, uplo upper_lower, int64_t n, std::complex alpha, + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *ap, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1517,8 +1509,8 @@ sycl::event hpr2(sycl::queue &queue, uplo upper_lower, int64_t n, } sycl::event sbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, float alpha, - const float *a, int64_t lda, const float *x, int64_t incx, float beta, - float *y, int64_t incy, const std::vector &dependencies) { + const float *a, int64_t lda, const float *x, int64_t incx, float beta, float *y, + int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1534,8 +1526,8 @@ sycl::event sbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, flo } sycl::event sbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, double alpha, - const double *a, int64_t lda, const double *x, int64_t incx, double beta, - double *y, int64_t incy, const std::vector &dependencies) { + const double *a, int64_t lda, const double *x, int64_t incx, double beta, + double *y, int64_t incy, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1550,9 +1542,9 @@ sycl::event sbmv(sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, dou return done; } -sycl::event spmv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const float *ap, const float *x, int64_t incx, float beta, float *y, - int64_t incy, const std::vector &dependencies) { +sycl::event spmv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, const float *ap, + const float *x, int64_t incx, float beta, float *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1567,9 +1559,9 @@ sycl::event spmv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, return done; } -sycl::event spmv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const double *ap, const double *x, int64_t incx, double beta, double *y, - int64_t incy, const std::vector &dependencies) { +sycl::event spmv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, const double *ap, + const double *x, int64_t incx, double beta, double *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1584,9 +1576,8 @@ sycl::event spmv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, return done; } -sycl::event spr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const float *x, int64_t incx, float *ap, - const std::vector &dependencies) { +sycl::event spr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, const float *x, + int64_t incx, float *ap, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1600,9 +1591,8 @@ sycl::event spr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, return done; } -sycl::event spr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const double *x, int64_t incx, double *ap, - const std::vector &dependencies) { +sycl::event spr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, const double *x, + int64_t incx, double *ap, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1616,9 +1606,9 @@ sycl::event spr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, return done; } -sycl::event spr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const float *x, int64_t incx, const float *y, int64_t incy, float *ap, - const std::vector &dependencies) { +sycl::event spr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, const float *x, + int64_t incx, const float *y, int64_t incy, float *ap, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1632,9 +1622,9 @@ sycl::event spr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, return done; } -sycl::event spr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const double *x, int64_t incx, const double *y, int64_t incy, double *ap, - const std::vector &dependencies) { +sycl::event spr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, const double *x, + int64_t incx, const double *y, int64_t incy, double *ap, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1648,9 +1638,9 @@ sycl::event spr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, return done; } -sycl::event symv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const float *a, int64_t lda, const float *x, int64_t incx, float beta, - float *y, int64_t incy, const std::vector &dependencies) { +sycl::event symv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, const float *a, + int64_t lda, const float *x, int64_t incx, float beta, float *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1665,9 +1655,9 @@ sycl::event symv(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, return done; } -sycl::event symv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const double *a, int64_t lda, const double *x, int64_t incx, double beta, - double *y, int64_t incy, const std::vector &dependencies) { +sycl::event symv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, const double *a, + int64_t lda, const double *x, int64_t incx, double beta, double *y, int64_t incy, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1682,9 +1672,8 @@ sycl::event symv(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, return done; } -sycl::event syr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const float *x, int64_t incx, float *a, int64_t lda, - const std::vector &dependencies) { +sycl::event syr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, const float *x, + int64_t incx, float *a, int64_t lda, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1698,9 +1687,9 @@ sycl::event syr(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, return done; } -sycl::event syr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const double *x, int64_t incx, double *a, int64_t lda, - const std::vector &dependencies) { +sycl::event syr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, const double *x, + int64_t incx, double *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1714,9 +1703,9 @@ sycl::event syr(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, return done; } -sycl::event syr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, - const float *x, int64_t incx, const float *y, int64_t incy, float *a, - int64_t lda, const std::vector &dependencies) { +sycl::event syr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, const float *x, + int64_t incx, const float *y, int64_t incy, float *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1731,9 +1720,9 @@ sycl::event syr2(sycl::queue &queue, uplo upper_lower, int64_t n, float alpha, return done; } -sycl::event syr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, - const double *x, int64_t incx, const double *y, int64_t incy, double *a, - int64_t lda, const std::vector &dependencies) { +sycl::event syr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, const double *x, + int64_t incx, const double *y, int64_t incy, double *a, int64_t lda, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1748,9 +1737,9 @@ sycl::event syr2(sycl::queue &queue, uplo upper_lower, int64_t n, double alpha, return done; } -sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const float *a, int64_t lda, float *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const float *a, int64_t lda, float *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1765,9 +1754,9 @@ sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const double *a, int64_t lda, double *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const double *a, int64_t lda, double *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1782,10 +1771,9 @@ sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const std::complex *a, int64_t lda, - std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const std::complex *a, int64_t lda, std::complex *x, + int64_t incx, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1800,10 +1788,9 @@ sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const std::complex *a, int64_t lda, - std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const std::complex *a, int64_t lda, std::complex *x, + int64_t incx, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1818,9 +1805,9 @@ sycl::event tbmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const float *a, int64_t lda, float *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const float *a, int64_t lda, float *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1835,9 +1822,9 @@ sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const double *a, int64_t lda, double *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const double *a, int64_t lda, double *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1852,10 +1839,9 @@ sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const std::complex *a, int64_t lda, - std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const std::complex *a, int64_t lda, std::complex *x, + int64_t incx, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1870,10 +1856,9 @@ sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, int64_t k, const std::complex *a, int64_t lda, - std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + int64_t k, const std::complex *a, int64_t lda, std::complex *x, + int64_t incx, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1888,9 +1873,9 @@ sycl::event tbsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const float *ap, float *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const float *ap, float *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1904,9 +1889,9 @@ sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const double *ap, double *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const double *ap, double *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1920,9 +1905,9 @@ sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const std::complex *ap, std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const std::complex *ap, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1936,9 +1921,9 @@ sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const std::complex *ap, std::complex *x, - int64_t incx, const std::vector &dependencies) { +sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const std::complex *ap, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1952,9 +1937,9 @@ sycl::event tpmv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const float *ap, float *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const float *ap, float *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1968,9 +1953,9 @@ sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const double *ap, double *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const double *ap, double *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1984,9 +1969,9 @@ sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const std::complex *ap, std::complex *x, int64_t incx, - const std::vector &dependencies) { +sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const std::complex *ap, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2000,9 +1985,9 @@ sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const std::complex *ap, std::complex *x, - int64_t incx, const std::vector &dependencies) { +sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const std::complex *ap, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2016,9 +2001,9 @@ sycl::event tpsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, - int64_t n, const float *a, int64_t lda, float *b, int64_t incx, - const std::vector &dependencies) { +sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, int64_t n, + const float *a, int64_t lda, float *b, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2033,9 +2018,9 @@ sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag un return done; } -sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, - int64_t n, const double *a, int64_t lda, double *b, int64_t incx, - const std::vector &dependencies) { +sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, int64_t n, + const double *a, int64_t lda, double *b, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2050,9 +2035,9 @@ sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag un return done; } -sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, - int64_t n, const std::complex *a, int64_t lda, std::complex *b, - int64_t incx, const std::vector &dependencies) { +sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, int64_t n, + const std::complex *a, int64_t lda, std::complex *b, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2067,9 +2052,9 @@ sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag un return done; } -sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, - int64_t n, const std::complex *a, int64_t lda, std::complex *b, - int64_t incx, const std::vector &dependencies) { +sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag unit_diag, int64_t n, + const std::complex *a, int64_t lda, std::complex *b, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2084,9 +2069,9 @@ sycl::event trmv(sycl::queue &queue, uplo upper_lower, transpose transa, diag un return done; } -sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const float *a, int64_t lda, float *x, int64_t incx, - const std::vector &dependencies) { +sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const float *a, int64_t lda, float *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2101,9 +2086,9 @@ sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const double *a, int64_t lda, double *x, int64_t incx, - const std::vector &dependencies) { +sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const double *a, int64_t lda, double *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2118,9 +2103,9 @@ sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const std::complex *a, int64_t lda, std::complex *x, - int64_t incx, const std::vector &dependencies) { +sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const std::complex *a, int64_t lda, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -2135,9 +2120,9 @@ sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag uni return done; } -sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, - int64_t n, const std::complex *a, int64_t lda, std::complex *x, - int64_t incx, const std::vector &dependencies) { +sycl::event trsv(sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, + const std::complex *a, int64_t lda, std::complex *x, int64_t incx, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { diff --git a/src/blas/backends/netlib/netlib_level3.cxx b/src/blas/backends/netlib/netlib_level3.cxx index 30ba631a4..8bb6a04ae 100644 --- a/src/blas/backends/netlib/netlib_level3.cxx +++ b/src/blas/backends/netlib/netlib_level3.cxx @@ -19,10 +19,9 @@ // Buffer APIs -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, sycl::buffer &a, int64_t lda, - sycl::buffer &b, int64_t ldb, float beta, sycl::buffer &c, - int64_t ldc) { +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -30,17 +29,16 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int host_task(cgh, [=]() { ::cblas_sgemm(MAJOR, convert_to_cblas_trans(transa), convert_to_cblas_trans(transb), (const int)m, (const int)n, (const int)k, (const float)alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const float)beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const float)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, double alpha, sycl::buffer &a, int64_t lda, - sycl::buffer &b, int64_t ldb, double beta, sycl::buffer &c, - int64_t ldc) { +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + double alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb, double beta, sycl::buffer &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -48,17 +46,17 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int host_task(cgh, [=]() { ::cblas_dgemm(MAJOR, convert_to_cblas_trans(transa), convert_to_cblas_trans(transb), (const int)m, (const int)n, (const int)k, (const double)alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const double)beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const double)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, sycl::buffer, 1> &b, int64_t ldb, - std::complex beta, sycl::buffer, 1> &c, int64_t ldc) { +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &b, int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -66,17 +64,17 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int host_task(cgh, [=]() { ::cblas_cgemm(MAJOR, convert_to_cblas_trans(transa), convert_to_cblas_trans(transb), (const int)m, (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const void *)&beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const void *)&beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, sycl::buffer, 1> &b, int64_t ldb, - std::complex beta, sycl::buffer, 1> &c, int64_t ldc) { +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &b, int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -84,15 +82,15 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int host_task(cgh, [=]() { ::cblas_zgemm(MAJOR, convert_to_cblas_trans(transa), convert_to_cblas_trans(transb), (const int)m, (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const void *)&beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const void *)&beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, sycl::half alpha, sycl::buffer &a, int64_t lda, +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + sycl::half alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, sycl::half beta, sycl::buffer &c, int64_t ldc) { #ifdef COLUMN_MAJOR @@ -103,10 +101,9 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int #endif } -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, sycl::buffer &a, int64_t lda, - sycl::buffer &b, int64_t ldb, float beta, - sycl::buffer &c, int64_t ldc) { +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -115,10 +112,9 @@ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int #endif } -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, sycl::buffer &a, int64_t lda, - sycl::buffer &b, int64_t ldb, float beta, sycl::buffer &c, - int64_t ldc) { +void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, + float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -138,9 +134,9 @@ void hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int6 host_task(cgh, [=]() { ::cblas_chemm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb, (const void *)&beta, - accessor_c.get_pointer(), (const int)ldc); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb, (const void *)&beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } @@ -156,38 +152,38 @@ void hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int6 host_task(cgh, [=]() { ::cblas_zhemm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb, (const void *)&beta, - accessor_c.get_pointer(), (const int)ldc); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb, (const void *)&beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer, 1> &a, int64_t lda, float beta, +void herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, float alpha, + sycl::buffer, 1> &a, int64_t lda, float beta, sycl::buffer, 1> &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_c = c.get_access(cgh); host_task(cgh, [=]() { ::cblas_cherk(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - (const int)n, (const int)k, (const float)alpha, accessor_a.get_pointer(), - (const int)lda, (const float)beta, accessor_c.get_pointer(), + (const int)n, (const int)k, (const float)alpha, accessor_a.GET_MULTI_PTR, + (const int)lda, (const float)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer, 1> &a, int64_t lda, double beta, +void herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, double alpha, + sycl::buffer, 1> &a, int64_t lda, double beta, sycl::buffer, 1> &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_c = c.get_access(cgh); host_task(cgh, [=]() { ::cblas_zherk(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - (const int)n, (const int)k, (const double)alpha, accessor_a.get_pointer(), - (const int)lda, (const double)beta, accessor_c.get_pointer(), + (const int)n, (const int)k, (const double)alpha, accessor_a.GET_MULTI_PTR, + (const int)lda, (const double)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); @@ -204,8 +200,8 @@ void her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int host_task(cgh, [=]() { ::cblas_cher2k(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const float)beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const float)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); @@ -222,16 +218,16 @@ void her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int host_task(cgh, [=]() { ::cblas_zher2k(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const double)beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const double)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, - int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { +void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, float alpha, + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, + float beta, sycl::buffer &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -239,16 +235,16 @@ void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int6 host_task(cgh, [=]() { ::cblas_ssymm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), (const int)m, (const int)n, - (const float)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb, (const float)beta, - accessor_c.get_pointer(), (const int)ldc); + (const float)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb, (const float)beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, - int64_t ldb, double beta, sycl::buffer &c, int64_t ldc) { +void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, double alpha, + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, + double beta, sycl::buffer &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -256,9 +252,9 @@ void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int6 host_task(cgh, [=]() { ::cblas_dsymm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), (const int)m, (const int)n, - (const double)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb, (const double)beta, - accessor_c.get_pointer(), (const int)ldc); + (const double)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb, (const double)beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } @@ -274,9 +270,9 @@ void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int6 host_task(cgh, [=]() { ::cblas_csymm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb, (const void *)&beta, - accessor_c.get_pointer(), (const int)ldc); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb, (const void *)&beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } @@ -292,38 +288,38 @@ void symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int6 host_task(cgh, [=]() { ::cblas_zsymm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb, (const void *)&beta, - accessor_c.get_pointer(), (const int)ldc); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb, (const void *)&beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &c, int64_t ldc) { +void syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, float alpha, + sycl::buffer &a, int64_t lda, float beta, sycl::buffer &c, + int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_c = c.get_access(cgh); host_task(cgh, [=]() { ::cblas_ssyrk(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - (const int)n, (const int)k, (const float)alpha, accessor_a.get_pointer(), - (const int)lda, (const float)beta, accessor_c.get_pointer(), + (const int)n, (const int)k, (const float)alpha, accessor_a.GET_MULTI_PTR, + (const int)lda, (const float)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &c, int64_t ldc) { +void syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, double alpha, + sycl::buffer &a, int64_t lda, double beta, sycl::buffer &c, + int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_c = c.get_access(cgh); host_task(cgh, [=]() { ::cblas_dsyrk(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - (const int)n, (const int)k, (const double)alpha, accessor_a.get_pointer(), - (const int)lda, (const double)beta, accessor_c.get_pointer(), + (const int)n, (const int)k, (const double)alpha, accessor_a.GET_MULTI_PTR, + (const int)lda, (const double)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); @@ -338,8 +334,8 @@ void syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int6 host_task(cgh, [=]() { ::cblas_csyrk(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, (const void *)&beta, - accessor_c.get_pointer(), (const int)ldc); + accessor_a.GET_MULTI_PTR, (const int)lda, (const void *)&beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } @@ -353,32 +349,31 @@ void syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int6 host_task(cgh, [=]() { ::cblas_zsyrk(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, (const void *)&beta, - accessor_c.get_pointer(), (const int)ldc); + accessor_a.GET_MULTI_PTR, (const int)lda, (const void *)&beta, + accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, - int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { +void syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, float alpha, + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, + float beta, sycl::buffer &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); auto accessor_c = c.get_access(cgh); host_task(cgh, [=]() { ::cblas_ssyr2k(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), - (const int)n, (const int)k, (const float)alpha, accessor_a.get_pointer(), - (const int)lda, accessor_b.get_pointer(), (const int)ldb, - (const float)beta, accessor_c.get_pointer(), (const int)ldc); + (const int)n, (const int)k, (const float)alpha, accessor_a.GET_MULTI_PTR, + (const int)lda, accessor_b.GET_MULTI_PTR, (const int)ldb, + (const float)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } void syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer &a, int64_t lda, - sycl::buffer &b, int64_t ldb, double beta, sycl::buffer &c, - int64_t ldc) { + double alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb, double beta, sycl::buffer &c, int64_t ldc) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -386,8 +381,8 @@ void syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int host_task(cgh, [=]() { ::cblas_dsyr2k(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const double)alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const double)beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const double)beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); @@ -404,8 +399,8 @@ void syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int host_task(cgh, [=]() { ::cblas_csyr2k(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const void *)&beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const void *)&beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); @@ -422,16 +417,16 @@ void syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int host_task(cgh, [=]() { ::cblas_zsyr2k(MAJOR, convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(trans), (const int)n, (const int)k, (const void *)&alpha, - accessor_a.get_pointer(), (const int)lda, accessor_b.get_pointer(), - (const int)ldb, (const void *)&beta, accessor_c.get_pointer(), + accessor_a.GET_MULTI_PTR, (const int)lda, accessor_b.GET_MULTI_PTR, + (const int)ldb, (const void *)&beta, accessor_c.GET_MULTI_PTR, (const int)ldc); }); }); } -void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, sycl::buffer &b, int64_t ldb) { +void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, + sycl::buffer &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -439,15 +434,15 @@ void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_strmm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const float)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const float)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, - int64_t lda, sycl::buffer &b, int64_t ldb) { +void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, double alpha, sycl::buffer &a, int64_t lda, + sycl::buffer &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -455,16 +450,15 @@ void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_dtrmm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const double)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const double)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { +void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, sycl::buffer, 1> &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -472,14 +466,14 @@ void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_ctrmm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, +void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, std::complex alpha, sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { @@ -489,15 +483,15 @@ void trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_ztrmm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, sycl::buffer &b, int64_t ldb) { +void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, float alpha, sycl::buffer &a, int64_t lda, + sycl::buffer &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -505,15 +499,15 @@ void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_strsm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const float)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const float)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, - int64_t lda, sycl::buffer &b, int64_t ldb) { +void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, double alpha, sycl::buffer &a, int64_t lda, + sycl::buffer &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -521,16 +515,15 @@ void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_dtrsm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const double)alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const double)alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { +void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, sycl::buffer, 1> &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { auto accessor_a = a.get_access(cgh); auto accessor_b = b.get_access(cgh); @@ -538,14 +531,14 @@ void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_ctrsm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } -void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, +void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, diag unit_diag, + int64_t m, int64_t n, std::complex alpha, sycl::buffer, 1> &a, int64_t lda, sycl::buffer, 1> &b, int64_t ldb) { queue.submit([&](sycl::handler &cgh) { @@ -555,18 +548,17 @@ void trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans ::cblas_ztrsm(MAJOR, convert_to_cblas_side(left_right), convert_to_cblas_uplo(upper_lower), convert_to_cblas_trans(transa), convert_to_cblas_diag(unit_diag), (const int)m, (const int)n, - (const void *)&alpha, accessor_a.get_pointer(), (const int)lda, - accessor_b.get_pointer(), (const int)ldb); + (const void *)&alpha, accessor_a.GET_MULTI_PTR, (const int)lda, + accessor_b.GET_MULTI_PTR, (const int)ldb); }); }); } // USM APIs -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, float alpha, const float *a, int64_t lda, const float *b, - int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, + float beta, float *c, int64_t ldc, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -581,10 +573,10 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t return done; } -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, double alpha, const double *a, int64_t lda, - const double *b, int64_t ldb, double beta, double *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, double alpha, const double *a, int64_t lda, const double *b, + int64_t ldb, double beta, double *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -599,11 +591,11 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t return done; } -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -619,11 +611,11 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t return done; } -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, std::complex alpha, - const std::complex *a, int64_t lda, const std::complex *b, - int64_t ldb, std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -639,10 +631,10 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t return done; } -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, sycl::half alpha, const sycl::half *a, int64_t lda, - const sycl::half *b, int64_t ldb, sycl::half beta, sycl::half *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, sycl::half alpha, const sycl::half *a, int64_t lda, const sycl::half *b, + int64_t ldb, sycl::half beta, sycl::half *c, int64_t ldc, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -651,10 +643,10 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t #endif } -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, float alpha, const sycl::half *a, int64_t lda, - const sycl::half *b, int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const sycl::half *a, int64_t lda, const sycl::half *b, + int64_t ldb, float beta, float *c, int64_t ldc, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -663,10 +655,10 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t #endif } -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, float alpha, const bfloat16 *a, int64_t lda, - const bfloat16 *b, int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, float alpha, const bfloat16 *a, int64_t lda, const bfloat16 *b, + int64_t ldb, float beta, float *c, int64_t ldc, + const std::vector &dependencies) { #ifdef COLUMN_MAJOR throw unimplemented("blas", "gemm", "for column_major layout"); #endif @@ -675,11 +667,11 @@ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t #endif } -sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -695,11 +687,11 @@ sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t return done; } -sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -715,10 +707,10 @@ sycl::event hemm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t return done; } -sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, float alpha, const std::complex *a, int64_t lda, float beta, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, const std::complex *a, int64_t lda, float beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -733,10 +725,10 @@ sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, double alpha, const std::complex *a, int64_t lda, - double beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, const std::complex *a, int64_t lda, double beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -751,11 +743,10 @@ sycl::event herk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, float beta, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, float beta, std::complex *c, + int64_t ldc, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -770,11 +761,10 @@ sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, double beta, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, double beta, std::complex *c, + int64_t ldc, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -789,10 +779,9 @@ sycl::event her2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, const float *b, - int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, + float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, float beta, + float *c, int64_t ldc, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -808,10 +797,10 @@ sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t return done; } -sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, const double *b, - int64_t ldb, double beta, double *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, + double alpha, const double *a, int64_t lda, const double *b, int64_t ldb, + double beta, double *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -827,11 +816,11 @@ sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t return done; } -sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -847,11 +836,11 @@ sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t return done; } -sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -867,9 +856,9 @@ sycl::event symm(sycl::queue &queue, side left_right, uplo upper_lower, int64_t return done; } -sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, float alpha, const float *a, int64_t lda, float beta, float *c, - int64_t ldc, const std::vector &dependencies) { +sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, const float *a, int64_t lda, float beta, float *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -884,9 +873,9 @@ sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, double alpha, const double *a, int64_t lda, double beta, double *c, - int64_t ldc, const std::vector &dependencies) { +sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, const double *a, int64_t lda, double beta, double *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -901,10 +890,10 @@ sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + std::complex beta, std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -919,10 +908,10 @@ sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + std::complex beta, std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -937,10 +926,9 @@ sycl::event syrk(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, float alpha, const float *a, int64_t lda, const float *b, - int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + float alpha, const float *a, int64_t lda, const float *b, int64_t ldb, float beta, + float *c, int64_t ldc, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -955,10 +943,10 @@ sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, double alpha, const double *a, int64_t lda, const double *b, - int64_t ldb, double beta, double *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + double alpha, const double *a, int64_t lda, const double *b, int64_t ldb, + double beta, double *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -973,11 +961,11 @@ sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -992,11 +980,11 @@ sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t return done; } -sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, - int64_t k, std::complex alpha, const std::complex *a, - int64_t lda, const std::complex *b, int64_t ldb, - std::complex beta, std::complex *c, int64_t ldc, - const std::vector &dependencies) { +sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *b, int64_t ldb, std::complex beta, + std::complex *c, int64_t ldc, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1012,8 +1000,8 @@ sycl::event syr2k(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t } sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, - float *b, int64_t ldb, const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, + float *b, int64_t ldb, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1030,9 +1018,8 @@ sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, - int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, int64_t lda, + double *b, int64_t ldb, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1049,9 +1036,9 @@ sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, - const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1068,9 +1055,9 @@ sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, std::complex *b, - int64_t ldb, const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1087,8 +1074,8 @@ sycl::event trmm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, - float *b, int64_t ldb, const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, int64_t lda, + float *b, int64_t ldb, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1105,9 +1092,8 @@ sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, - int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, int64_t lda, + double *b, int64_t ldb, const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1124,9 +1110,9 @@ sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, - const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1143,9 +1129,9 @@ sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpos } sycl::event trsm(sycl::queue &queue, side left_right, uplo upper_lower, transpose transa, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, std::complex *b, - int64_t ldb, const std::vector &dependencies) { + diag unit_diag, int64_t m, int64_t n, std::complex alpha, + const std::complex *a, int64_t lda, std::complex *b, int64_t ldb, + const std::vector &dependencies) { auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { diff --git a/src/blas/backends/portblas/CMakeLists.txt b/src/blas/backends/portblas/CMakeLists.txt new file mode 100644 index 000000000..03fddbb38 --- /dev/null +++ b/src/blas/backends/portblas/CMakeLists.txt @@ -0,0 +1,222 @@ +#========================================================================== +# Copyright (C) Codeplay Software Limited +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# For your convenience, a copy of the License has been included in this +# repository. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +#========================================================================= + +set(LIB_NAME onemkl_blas_portblas) +set(LIB_OBJ ${LIB_NAME}_obj) + +if(NOT DEFINED PORTBLAS_TUNING_TARGET) + option(PORTBLAS_TUNING_TARGET "Set a TUNING_TARGET for portBLAS" "") +endif() + +# Parse compiler flags and return a list of SYCL targets +# The list is empty if no targets are set +function(get_sycl_targets FLAGS) + string(REGEX MATCH "-fsycl-targets=[^ ]*" SYCL_TARGETS_FLAG "${FLAGS}") + string(REPLACE "-fsycl-targets=" "" SYCL_TARGETS "${SYCL_TARGETS_FLAG}") + string(REPLACE "," ";" SYCL_TARGETS "${SYCL_TARGETS}") + set(SYCL_TARGETS ${SYCL_TARGETS} PARENT_SCOPE) +endfunction(get_sycl_targets) + +# portBLAS supports tuning for some device types, but can only be compiled +# for one at a time currently. Work out which device to tune for based on the +# DPC++ target triple specified via -fsycl-targets +if(TARGET ONEMKL::SYCL::SYCL) + get_target_property(ONEMKL_COMPILE_OPTIONS ONEMKL::SYCL::SYCL INTERFACE_COMPILE_OPTIONS) +endif() +get_sycl_targets("${ONEMKL_COMPILE_OPTIONS}") +list(LENGTH SYCL_TARGETS NUM_TARGETS) +if(NUM_TARGETS EQUAL 0) + get_sycl_targets("${CMAKE_CXX_FLAGS}") + list(LENGTH SYCL_TARGETS NUM_TARGETS) +endif() + +if(PORTBLAS_TUNING_TARGET) + # Allow the user to manually enable a specific device type + # for tuned portBLAS configurations and sets sycl-target. + if(PORTBLAS_TUNING_TARGET STREQUAL "INTEL_CPU") + set(ENABLE_PORTBLAS_BACKEND_INTEL_CPU "ON" CACHE INTERNAL "") + set(PORTBLAS_TUNING_TARGET "") + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=spir64_x86_64 -fsycl-unnamed-lambda) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=spir64_x86_64) + elseif(PORTBLAS_TUNING_TARGET STREQUAL "INTEL_GPU") + set(ENABLE_PORTBLAS_BACKEND_INTEL_GPU "ON" CACHE INTERNAL "") + elseif(PORTBLAS_TUNING_TARGET STREQUAL "AMD_GPU") + set(ENABLE_PORTBLAS_BACKEND_AMD_GPU "ON" CACHE INTERNAL "") + if (is_dpcpp) + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=amdgcn-amd-amdhsa -fsycl-unnamed-lambda + -Xsycl-target-backend --offload-arch=${HIP_TARGETS}) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend --offload-arch=${HIP_TARGETS}) + else() + message(WARNING "Compiler is not supported." + " Unable to automatically set the required flags for the target '${PORTBLAS_TUNING_TARGET}'." + " Compilation may fail.") + endif() + elseif(PORTBLAS_TUNING_TARGET STREQUAL "NVIDIA_GPU") + set(ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU "ON" CACHE INTERNAL "") + if (is_dpcpp) + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=nvptx64-nvidia-cuda -fsycl-unnamed-lambda) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=nvptx64-nvidia-cuda) + if(DEFINED CUDA_TARGET) + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE + -Xsycl-target-backend --cuda-gpu-arch=${CUDA_TARGET}) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE + -Xsycl-target-backend --cuda-gpu-arch=${CUDA_TARGET}) + endif() + else() + message(WARNING "Compiler is not supported." + " Unable to automatically set the required flags for the target '${PORTBLAS_TUNING_TARGET}'." + " Compilation may fail.") + endif() + else() + message(FATAL_ERROR "Unsupported PORTBLAS_TUNING_TARGET: '${PORTBLAS_TUNING_TARGET}'") + endif() +elseif(NUM_TARGETS EQUAL 0) + # Enable portBLAS backend for all devices types + set(ENABLE_PORTBLAS_BACKEND_INTEL_CPU "ON" CACHE INTERNAL "") + set(ENABLE_PORTBLAS_BACKEND_INTEL_GPU "ON" CACHE INTERNAL "") + set(ENABLE_PORTBLAS_BACKEND_AMD_GPU "ON" CACHE INTERNAL "") + set(ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU "ON" CACHE INTERNAL "") +else() + # Try to automatically detect the PORTBLAS_TUNING_TARGET + foreach(SYCL_TARGET IN LISTS SYCL_TARGETS) + if(SYCL_TARGETS MATCHES "^intel_gpu" OR SYCL_TARGETS MATCHES "^spir64_gen") + set(ENABLE_PORTBLAS_BACKEND_INTEL_GPU "ON" CACHE INTERNAL "") + set(PORTBLAS_TUNING_TARGET "INTEL_GPU") + elseif(SYCL_TARGETS MATCHES "^spir64_x86_64") + set(ENABLE_PORTBLAS_BACKEND_INTEL_CPU "ON" CACHE INTERNAL "") + elseif(SYCL_TARGETS MATCHES "^spir64") + set(ENABLE_PORTBLAS_BACKEND_INTEL_CPU "ON" CACHE INTERNAL "") + set(ENABLE_PORTBLAS_BACKEND_INTEL_GPU "ON" CACHE INTERNAL "") + set(PORTBLAS_TUNING_TARGET "INTEL_GPU") + elseif(SYCL_TARGETS MATCHES "^amd_gpu" OR SYCL_TARGETS MATCHES "-amd-") + set(ENABLE_PORTBLAS_BACKEND_AMD_GPU "ON" CACHE INTERNAL "") + set(PORTBLAS_TUNING_TARGET "AMD_GPU") + elseif(SYCL_TARGETS MATCHES "^nvidia_gpu" OR SYCL_TARGETS MATCHES "-nvidia-") + set(ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU "ON" CACHE INTERNAL "") + set(PORTBLAS_TUNING_TARGET "NVIDIA_GPU") + endif() + endforeach() + # Currently portBLAS can only be tuned for one type of device. + if(NUM_TARGETS GREATER 1) + set(PORTBLAS_TUNING_TARGET "") + endif() +endif() + +if(PORTBLAS_TUNING_TARGET STREQUAL "INTEL_GPU") + message(STATUS "Tuning portBLAS for Intel GPU devices") +elseif(PORTBLAS_TUNING_TARGET STREQUAL "AMD_GPU") + message(STATUS "Tuning portBLAS for AMD GPU devices") +elseif(PORTBLAS_TUNING_TARGET STREQUAL "NVIDIA_GPU") + message(STATUS "Tuning portBLAS for Nvidia GPU devices") +else() + message(STATUS "portBLAS is not tuned for any device which can impact performance") +endif() + +# If find_package doesn't work, download portBLAS from Github. This is +# intended to make OneMKL easier to use. +message(STATUS "Looking for portBLAS") +find_package(PORTBLAS QUIET) +if (NOT PORTBLAS_FOUND) + message(STATUS "Looking for portBLAS - could not find portBLAS with PORTBLAS_DIR") + include(FetchContent) + set(INSTALL_HEADER_ONLY ON) + set(BLAS_BUILD_SAMPLES OFF) + set(BLAS_ENABLE_BENCHMARK OFF) + set(BLAS_ENABLE_TESTING OFF) + set(ENABLE_EXPRESSION_TESTS OFF) + if(NOT PORTBLAS_TUNING_TARGET) + set(PORTBLAS_TUNING_TARGET "DEFAULT") + endif() + # Following variable TUNING_TARGET will be used in portBLAS internal configuration + set(TUNING_TARGET ${PORTBLAS_TUNING_TARGET}) + set(BLAS_ENABLE_COMPLEX ON) + # Set the policy to forward variables to portBLAS configure step + set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps") + FetchContent_Declare( + portBLAS + GIT_REPOSITORY https://github.com/codeplaysoftware/portBLAS + GIT_TAG main + ) + FetchContent_MakeAvailable(portblas) + message(STATUS "Looking for portBLAS - downloaded") + +else() + message(STATUS "Looking for portBLAS - found") + add_library(portblas ALIAS PORTBLAS::portblas) +endif() + +set(SOURCES + portblas_level1_double.cpp portblas_level1_float.cpp + portblas_level2_double.cpp portblas_level2_float.cpp + portblas_level3_double.cpp portblas_level3_float.cpp + portblas_level3_half.cpp portblas_level3_bfloat16.cpp + portblas_batch.cpp + $<$: portblas_wrappers.cpp>) +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) + +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) +endif() + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL portblas) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON) + +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/blas/backends/portblas/portblas_batch.cpp b/src/blas/backends/portblas/portblas_batch.cpp new file mode 100644 index 000000000..65f0cd59e --- /dev/null +++ b/src/blas/backends/portblas/portblas_batch.cpp @@ -0,0 +1,57 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_batch.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_batch.cxx" +#undef ROW_MAJOR + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_batch.cxx b/src/blas/backends/portblas/portblas_batch.cxx new file mode 100644 index 000000000..28c7ee5dc --- /dev/null +++ b/src/blas/backends/portblas/portblas_batch.cxx @@ -0,0 +1,1017 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// Buffer APIs + +void syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", ""); +} + +void syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", ""); +} + +void syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", ""); +} + +void syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "syrk_batch", ""); +} + +void gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, float beta, + sycl::buffer &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", ""); +} + +void gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, double beta, + sycl::buffer &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", ""); +} + +void gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", ""); +} + +void gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + throw unimplemented("blas", "gemv_batch", ""); +} + +void dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, std::int64_t n, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", ""); +} + +void dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, std::int64_t n, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer &c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", ""); +} + +void dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, std::int64_t n, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer, 1> &c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", ""); +} + +void dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, std::int64_t n, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer, 1> &c, std::int64_t ldc, std::int64_t stridec, + std::int64_t batch_size) { + throw unimplemented("blas", "dgmm_batch", ""); +} + +void axpy_batch(sycl::queue &queue, std::int64_t n, float alpha, sycl::buffer &x, + std::int64_t incx, std::int64_t stridex, sycl::buffer &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey, + batch_size); +} + +void axpy_batch(sycl::queue &queue, std::int64_t n, double alpha, sycl::buffer &x, + std::int64_t incx, std::int64_t stridex, sycl::buffer &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey, + batch_size); +} + +void axpy_batch(sycl::queue &queue, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", ""); +} + +void axpy_batch(sycl::queue &queue, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, std::int64_t stridex, + sycl::buffer, 1> &y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size) { + throw unimplemented("blas", "axpy_batch", ""); +} +void copy_batch(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + std::int64_t stridex, sycl::buffer &y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size) { + throw unimplemented("blas", "copy_batch", ""); +} + +void copy_batch(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + std::int64_t stridex, sycl::buffer &y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size) { + throw unimplemented("blas", "copy_batch", ""); +} + +void copy_batch(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, sycl::buffer, 1> &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + throw unimplemented("blas", "copy_batch", ""); +} + +void copy_batch(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, std::int64_t stridex, sycl::buffer, 1> &y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size) { + throw unimplemented("blas", "copy_batch", ""); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, double alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, double beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, lda, + stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::complex beta, sycl::buffer, 1> &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for complex"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::complex beta, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for complex"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + sycl::half beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for complex"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +} + +void gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "gemm_batch", " for unsupported dtype"); +} + +void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", ""); +} + +void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", ""); +} + +void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", ""); +} + +void trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer, 1> &b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size) { + throw unimplemented("blas", "trsm_batch", ""); +} + +void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, ldb, + stride_b, batch_size); +} + +void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", ""); +} + +void omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size) { + throw unimplemented("blas", "omatcopy_batch", ""); +} + +void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", ""); +} + +void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", ""); +} + +void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", ""); +} + +void imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + sycl::buffer, 1> &ab, std::int64_t lda, std::int64_t ldb, + std::int64_t stride, std::int64_t batch_size) { + throw unimplemented("blas", "imatcopy_batch", ""); +} + +void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, float beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, double alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, double beta, sycl::buffer &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + CALL_PORTBLAS_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, stride_a, + beta, b, ldb, stride_b, c, ldc, stride_c, batch_size); +} + +void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::int64_t stride_a, + std::complex beta, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", ""); +} + +void omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t stride_b, sycl::buffer, 1> &c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size) { + throw unimplemented("blas", "omatadd_batch", ""); +} + +// USM APIs + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo *upper_lower, + oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *k, + float *alpha, const float **a, std::int64_t *lda, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo *upper_lower, + oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *k, + double *alpha, const double **a, std::int64_t *lda, double *beta, double **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo *upper_lower, + oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *k, + std::complex *alpha, const std::complex **a, std::int64_t *lda, + std::complex *beta, std::complex **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo *upper_lower, + oneapi::mkl::transpose *trans, std::int64_t *n, std::int64_t *k, + std::complex *alpha, const std::complex **a, + std::int64_t *lda, std::complex *beta, std::complex **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double beta, + double *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event syrk_batch(sycl::queue &queue, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, std::int64_t n, std::int64_t k, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, std::complex *c, + std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, const float *x, std::int64_t incx, + std::int64_t stridex, float beta, float *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stridea, const double *x, std::int64_t incx, + std::int64_t stridex, double beta, double *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex beta, + std::complex *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stridea, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex beta, + std::complex *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *m, + std::int64_t *n, float *alpha, const float **a, std::int64_t *lda, + const float **x, std::int64_t *incx, float *beta, float **y, + std::int64_t *incy, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *m, + std::int64_t *n, double *alpha, const double **a, std::int64_t *lda, + const double **x, std::int64_t *incx, double *beta, double **y, + std::int64_t *incy, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *m, + std::int64_t *n, std::complex *alpha, const std::complex **a, + std::int64_t *lda, const std::complex **x, std::int64_t *incx, + std::complex *beta, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event gemv_batch(sycl::queue &queue, oneapi::mkl::transpose *trans, std::int64_t *m, + std::int64_t *n, std::complex *alpha, const std::complex **a, + std::int64_t *lda, const std::complex **x, std::int64_t *incx, + std::complex *beta, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, + std::int64_t n, const float *a, std::int64_t lda, std::int64_t stridea, + const float *x, std::int64_t incx, std::int64_t stridex, float *c, + std::int64_t ldc, std::int64_t stridec, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, + std::int64_t n, const double *a, std::int64_t lda, std::int64_t stridea, + const double *x, std::int64_t incx, std::int64_t stridex, double *c, + std::int64_t ldc, std::int64_t stridec, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, + std::int64_t n, const std::complex *a, std::int64_t lda, + std::int64_t stridea, const std::complex *x, std::int64_t incx, + std::int64_t stridex, std::complex *c, std::int64_t ldc, + std::int64_t stridec, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side left_right, std::int64_t m, + std::int64_t n, const std::complex *a, std::int64_t lda, + std::int64_t stridea, const std::complex *x, std::int64_t incx, + std::int64_t stridex, std::complex *c, std::int64_t ldc, + std::int64_t stridec, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, std::int64_t *m, + std::int64_t *n, const float **a, std::int64_t *lda, const float **x, + std::int64_t *incx, float **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, std::int64_t *m, + std::int64_t *n, const double **a, std::int64_t *lda, const double **x, + std::int64_t *incx, double **c, std::int64_t *ldc, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, std::int64_t *m, + std::int64_t *n, const std::complex **a, std::int64_t *lda, + const std::complex **x, std::int64_t *incx, std::complex **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event dgmm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, std::int64_t *m, + std::int64_t *n, const std::complex **a, std::int64_t *lda, + const std::complex **x, std::int64_t *incx, std::complex **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "dgmm_batch", " for USM"); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t *n, float *alpha, const float **x, + std::int64_t *incx, float **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", " for USM"); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t *n, double *alpha, const double **x, + std::int64_t *incx, double **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", " for USM"); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t *n, std::complex *alpha, + const std::complex **x, std::int64_t *incx, std::complex **y, + std::int64_t *incy, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", " for USM"); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t *n, std::complex *alpha, + const std::complex **x, std::int64_t *incx, std::complex **y, + std::int64_t *incy, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", " for USM"); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, float alpha, const float *x, + std::int64_t incx, std::int64_t stridex, float *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey, + batch_size, dependencies); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, double alpha, const double *x, + std::int64_t incx, std::int64_t stridex, double *y, std::int64_t incy, + std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_axpy_batch, queue, n, alpha, x, incx, stridex, y, incy, stridey, + batch_size, dependencies); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, std::complex alpha, + const std::complex *x, std::int64_t incx, std::int64_t stridex, + std::complex *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", " for USM"); +} + +sycl::event axpy_batch(sycl::queue &queue, std::int64_t n, std::complex alpha, + const std::complex *x, std::int64_t incx, std::int64_t stridex, + std::complex *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "axpy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t *n, const float **x, std::int64_t *incx, + float **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t *n, const double **x, std::int64_t *incx, + double **y, std::int64_t *incy, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t *n, const std::complex **x, + std::int64_t *incx, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t *n, const std::complex **x, + std::int64_t *incx, std::complex **y, std::int64_t *incy, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t n, const float *x, std::int64_t incx, + std::int64_t stridex, float *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t n, const double *x, std::int64_t incx, + std::int64_t stridex, double *y, std::int64_t incy, std::int64_t stridey, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event copy_batch(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, std::int64_t stridex, std::complex *y, + std::int64_t incy, std::int64_t stridey, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "copy_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const float **a, std::int64_t *lda, + const float **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, double *alpha, const double **a, std::int64_t *lda, + const double **b, std::int64_t *ldb, double *beta, double **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, std::complex *alpha, const std::complex **a, + std::int64_t *lda, const std::complex **b, std::int64_t *ldb, + std::complex *beta, std::complex **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, std::complex *alpha, const std::complex **a, + std::int64_t *lda, const std::complex **b, std::int64_t *ldb, + std::complex *beta, std::complex **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, sycl::half *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const sycl::half **a, std::int64_t *lda, + const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose *transa, + oneapi::mkl::transpose *transb, std::int64_t *m, std::int64_t *n, + std::int64_t *k, float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, const float *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, const double *b, std::int64_t ldb, + std::int64_t stride_b, double beta, double *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_gemm_strided_batched, queue, transa, transb, m, n, k, alpha, a, + lda, stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, + dependencies); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, const std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::complex beta, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, sycl::half alpha, const sycl::half *a, std::int64_t lda, + std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const sycl::half *a, std::int64_t lda, + std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event gemm_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::int64_t k, float alpha, const std::int8_t *a, std::int64_t lda, + std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, std::int64_t stride_a, float *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, std::int64_t stride_a, double *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side left_right, + oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, + oneapi::mkl::uplo *upper_lower, oneapi::mkl::transpose *trans, + oneapi::mkl::diag *unit_diag, std::int64_t *m, std::int64_t *n, float *alpha, + const float **a, std::int64_t *lda, float **b, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, + oneapi::mkl::uplo *upper_lower, oneapi::mkl::transpose *trans, + oneapi::mkl::diag *unit_diag, std::int64_t *m, std::int64_t *n, + double *alpha, const double **a, std::int64_t *lda, double **b, + std::int64_t *ldb, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, + oneapi::mkl::uplo *upper_lower, oneapi::mkl::transpose *trans, + oneapi::mkl::diag *unit_diag, std::int64_t *m, std::int64_t *n, + std::complex *alpha, const std::complex **a, std::int64_t *lda, + std::complex **b, std::int64_t *ldb, std::int64_t group_count, + std::int64_t *group_size, const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event trsm_batch(sycl::queue &queue, oneapi::mkl::side *left_right, + oneapi::mkl::uplo *upper_lower, oneapi::mkl::transpose *trans, + oneapi::mkl::diag *unit_diag, std::int64_t *m, std::int64_t *n, + std::complex *alpha, const std::complex **a, + std::int64_t *lda, std::complex **b, std::int64_t *ldb, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm_batch", " for USM"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stride_a, float *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, const double *a, std::int64_t lda, + std::int64_t stride_a, double *b, std::int64_t ldb, + std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatcopy_batch, queue, trans, m, n, alpha, a, lda, stride_a, b, + ldb, stride_b, batch_size, dependencies); +} + +sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex *b, + std::int64_t ldb, std::int64_t stride_b, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", " for USM"); +} + +sycl::event omatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stride_a, + std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", " for USM"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, float alpha, float *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", " for USM"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, double alpha, double *ab, std::int64_t lda, + std::int64_t ldb, std::int64_t stride, std::int64_t batch_size, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", " for USM"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", " for USM"); +} + +sycl::event imatcopy_batch(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, std::complex *ab, + std::int64_t lda, std::int64_t ldb, std::int64_t stride, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", " for USM"); +} + +sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, + float beta, const float *b, std::int64_t ldb, std::int64_t stride_b, + float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, + dependencies); +} + +sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, std::int64_t stride_a, + double beta, const double *b, std::int64_t ldb, std::int64_t stride_b, + double *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatadd_batch, queue, transa, transb, m, n, alpha, a, lda, + stride_a, beta, b, ldb, stride_b, c, ldc, stride_c, batch_size, + dependencies); +} + +sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", " for USM"); +} + +sycl::event omatadd_batch(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, + std::int64_t lda, std::int64_t stride_a, std::complex beta, + const std::complex *b, std::int64_t ldb, std::int64_t stride_b, + std::complex *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + throw unimplemented("blas", "omatadd_batch", " for USM"); +} diff --git a/src/blas/backends/portblas/portblas_common.hpp b/src/blas/backends/portblas/portblas_common.hpp new file mode 100644 index 000000000..1624749e8 --- /dev/null +++ b/src/blas/backends/portblas/portblas_common.hpp @@ -0,0 +1,239 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _PORTBLAS_COMMON_HPP_ +#define _PORTBLAS_COMMON_HPP_ + +#include "portblas.hpp" +#include "oneapi/mkl/types.hpp" +#include "oneapi/mkl/exceptions.hpp" + +#include +#include + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +namespace detail { +// portBLAS handle type. Constructed with sycl::queue. +using handle_t = ::blas::SB_Handle; + +// portBLAS buffer iterator. Constructed with sycl::buffer +template +using buffer_iterator_t = ::blas::BufferIterator; + +// sycl complex data type (experimental) +template +using sycl_complex_t = sycl::ext::oneapi::experimental::complex; + +/** A trait for obtaining equivalent portBLAS API types from oneMKL API + * types. + * + * @tparam InputT is the oneMKL type. + * portblas_type::type should be the equivalent portBLAS type. +**/ +template +struct portblas_type; + +#define DEF_PORTBLAS_TYPE(onemkl_t, portblas_t) \ + template <> \ + struct portblas_type { \ + using type = portblas_t; \ + }; + +DEF_PORTBLAS_TYPE(sycl::queue, handle_t) +DEF_PORTBLAS_TYPE(int64_t, int64_t) +DEF_PORTBLAS_TYPE(sycl::half, sycl::half) +DEF_PORTBLAS_TYPE(float, float) +DEF_PORTBLAS_TYPE(double, double) +DEF_PORTBLAS_TYPE(oneapi::mkl::transpose, char) +DEF_PORTBLAS_TYPE(oneapi::mkl::uplo, char) +DEF_PORTBLAS_TYPE(oneapi::mkl::side, char) +DEF_PORTBLAS_TYPE(oneapi::mkl::diag, char) +DEF_PORTBLAS_TYPE(std::complex, sycl_complex_t) +DEF_PORTBLAS_TYPE(std::complex, sycl_complex_t) +// Passthrough of portBLAS arg types for more complex wrapping. +DEF_PORTBLAS_TYPE(::blas::gemm_batch_type_t, ::blas::gemm_batch_type_t) + +#undef DEF_PORTBLAS_TYPE + +template +struct portblas_type> { + using type = buffer_iterator_t; +}; + +template +struct portblas_type { + using type = ElemT*; +}; + +// USM Complex +template +struct portblas_type*> { + using type = sycl_complex_t*; +}; + +template +struct portblas_type*> { + using type = const sycl_complex_t*; +}; + +template <> +struct portblas_type> { + using type = std::vector; +}; + +/** Convert a OneMKL argument to the type required for portBLAS. + * + * @tparam InputT The OneMKL type. + * @param input The value of the oneMKL type. + * @return The portBLAS value with appropriate type. +**/ +template +inline typename portblas_type::type convert_to_portblas_type(InputT& input) { + return typename portblas_type::type(input); +} + +template <> +inline char convert_to_portblas_type(oneapi::mkl::transpose& trans) { + if (trans == oneapi::mkl::transpose::nontrans) { + return 'n'; + } + else if (trans == oneapi::mkl::transpose::trans) { + return 't'; + } + else { // trans == oneapi::mkl::transpose::conjtrans + return 'c'; + } +} + +template <> +inline char convert_to_portblas_type(oneapi::mkl::uplo& upper_lower) { + if (upper_lower == oneapi::mkl::uplo::upper) { + return 'u'; + } + else { + return 'l'; + } +} + +template <> +inline char convert_to_portblas_type(oneapi::mkl::side& left_right) { + if (left_right == oneapi::mkl::side::left) { + return 'l'; + } + else { + return 'r'; + } +} + +template <> +inline char convert_to_portblas_type(oneapi::mkl::diag& unit_diag) { + if (unit_diag == oneapi::mkl::diag::unit) { + return 'u'; + } + else { + return 'n'; + } +} + +template +inline auto convert_to_portblas_type(ArgT... args) { + return std::make_tuple(convert_to_portblas_type(args)...); +} + +/** Throw an unsupported_device exception if a certain argument type is found in + * the argument pack. + * + * @tparam CheckT is type to look for a template parameter pack. + * @tparam AspectVal is the device aspect required to support CheckT. +**/ +template +struct throw_if_unsupported_by_device { + /** Operator to throw if unsupported. + * + * @tparam ArgTs The argument types to check. + * @param The message to include in the exception. + * @param q is the sycl::queue. + * @param args are the remaining args to check for CheckT in. +**/ + template + void operator()(const std::string& message, sycl::queue q, ArgTs... args) { + static constexpr bool checkTypeInPack = (std::is_same_v || ...); + if (checkTypeInPack) { + if (!q.get_info().has(AspectVal)) { + throw mkl::unsupported_device("blas", message, + q.get_info()); + } + } + } +}; + +} // namespace detail + +#define CALL_PORTBLAS_FN(portBLASFunc, ...) \ + if constexpr (is_column_major()) { \ + detail::throw_if_unsupported_by_device, sycl::aspect::fp64>{}( \ + " portBLAS function requiring fp64 support", __VA_ARGS__); \ + detail::throw_if_unsupported_by_device, sycl::aspect::fp16>{}( \ + " portBLAS function requiring fp16 support", __VA_ARGS__); \ + auto args = detail::convert_to_portblas_type(__VA_ARGS__); \ + auto fn = [](auto&&... targs) { \ + portBLASFunc(std::forward(targs)...); \ + }; \ + try { \ + std::apply(fn, args); \ + } \ + catch (const ::blas::unsupported_exception& e) { \ + throw unimplemented("blas", e.what()); \ + } \ + } \ + else { \ + throw unimplemented("blas", "portBLAS function"); \ + } + +#define CALL_PORTBLAS_USM_FN(portblasFunc, ...) \ + if constexpr (is_column_major()) { \ + detail::throw_if_unsupported_by_device{}( \ + " portBLAS function requiring fp64 support", __VA_ARGS__); \ + detail::throw_if_unsupported_by_device{}( \ + " portBLAS function requiring fp16 support", __VA_ARGS__); \ + auto args = detail::convert_to_portblas_type(__VA_ARGS__); \ + auto fn = [](auto&&... targs) { \ + return portblasFunc(std::forward(targs)...).back(); \ + }; \ + try { \ + return std::apply(fn, args); \ + } \ + catch (const ::blas::unsupported_exception& e) { \ + throw unimplemented("blas", e.what()); \ + } \ + } \ + else { \ + throw unimplemented("blas", "portBLAS function"); \ + } + +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi + +#endif // _PORTBLAS_COMMON_HPP_ diff --git a/src/blas/backends/portblas/portblas_gemm_bias.cxx b/src/blas/backends/portblas/portblas_gemm_bias.cxx new file mode 100644 index 000000000..30f638f3e --- /dev/null +++ b/src/blas/backends/portblas/portblas_gemm_bias.cxx @@ -0,0 +1,90 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// Buffer APIs + +void gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + oneapi::mkl::offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, int8_t ao, + sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, + sycl::buffer &c, std::int64_t ldc, sycl::buffer &co) { + throw unimplemented("blas", "gemm_bias", ""); +} + +void gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + oneapi::mkl::offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, int8_t ao, + sycl::buffer &b, std::int64_t ldb, int8_t bo, float beta, + sycl::buffer &c, std::int64_t ldc, sycl::buffer &co) { + throw unimplemented("blas", "gemm_bias", ""); +} + +void gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + oneapi::mkl::offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, uint8_t ao, + sycl::buffer &b, std::int64_t ldb, int8_t bo, float beta, + sycl::buffer &c, std::int64_t ldc, sycl::buffer &co) { + throw unimplemented("blas", "gemm_bias", ""); +} + +void gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + oneapi::mkl::offset offsetc, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, sycl::buffer &a, std::int64_t lda, uint8_t ao, + sycl::buffer &b, std::int64_t ldb, uint8_t bo, float beta, + sycl::buffer &c, std::int64_t ldc, sycl::buffer &co) { + throw unimplemented("blas", "gemm_bias", ""); +} + +// USM APIs + +sycl::event gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, oneapi::mkl::offset offsetc, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int8_t ao, const std::uint8_t *b, std::int64_t ldb, + std::uint8_t bo, float beta, std::int32_t *c, std::int64_t ldc, + const std::int32_t *co, const std::vector &dependencies) { + throw unimplemented("blas", "gemm_bias", " for USM"); +} + +sycl::event gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, oneapi::mkl::offset offsetc, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int8_t ao, const std::int8_t *b, std::int64_t ldb, + std::int8_t bo, float beta, std::int32_t *c, std::int64_t ldc, + const std::int32_t *co, const std::vector &dependencies) { + throw unimplemented("blas", "gemm_bias", " for USM"); +} + +sycl::event gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, oneapi::mkl::offset offsetc, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::uint8_t *a, + std::int64_t lda, std::uint8_t ao, const std::int8_t *b, std::int64_t ldb, + std::int8_t bo, float beta, std::int32_t *c, std::int64_t ldc, + const std::int32_t *co, const std::vector &dependencies) { + throw unimplemented("blas", "gemm_bias", " for USM"); +} + +sycl::event gemm_bias(sycl::queue &queue, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, oneapi::mkl::offset offsetc, std::int64_t m, + std::int64_t n, std::int64_t k, float alpha, const std::uint8_t *a, + std::int64_t lda, std::uint8_t ao, const std::uint8_t *b, std::int64_t ldb, + std::uint8_t bo, float beta, std::int32_t *c, std::int64_t ldc, + const std::int32_t *co, const std::vector &dependencies) { + throw unimplemented("blas", "gemm_bias", " for USM"); +} diff --git a/src/blas/backends/portblas/portblas_level1.cxx b/src/blas/backends/portblas/portblas_level1.cxx new file mode 100644 index 000000000..e1e1f2f60 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level1.cxx @@ -0,0 +1,410 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// Buffer APIs + +void dotc(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &result) { + throw unimplemented("blas", "dotc", ""); +} + +void dotu(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &result) { + throw unimplemented("blas", "dotu", ""); +} + +void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &result) { + CALL_PORTBLAS_FN(::blas::_iamax, queue, n, x, incx, result); +} + +void iamax(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer &result) { + throw unimplemented("blas", "iamax", ""); +} + +void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &result) { + CALL_PORTBLAS_FN(::blas::_iamin, queue, n, x, incx, result); +} + +void iamin(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer &result) { + throw unimplemented("blas", "iamin", ""); +} + +void asum(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer &result) { + throw unimplemented("blas", "asum", ""); +} + +void asum(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &result) { + // portBLAS asum implementation requires that result is initialized to zero + // before performing the computation. + queue.submit([&](sycl::handler &cgh) { + auto result_acc = result.template get_access(cgh); + cgh.single_task([=]() { result_acc[0] = real_t(0); }); + }); + CALL_PORTBLAS_FN(::blas::_asum, queue, n, x, incx, result); +} + +void axpy(sycl::queue &queue, std::int64_t n, real_t alpha, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_axpy, queue, n, alpha, x, incx, y, incy); +} + +void axpy(sycl::queue &queue, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "axpy", "for complex"); +} + +void axpby(sycl::queue &queue, std::int64_t n, real_t alpha, sycl::buffer &x, + std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { + throw unimplemented("blas", "axpby", ""); +} + +void axpby(sycl::queue &queue, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "axpby", ""); +} + +void copy(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_copy, queue, n, x, incx, y, incy); +} + +void copy(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "copy", " for complex."); +} + +void dot(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &y, std::int64_t incy, sycl::buffer &result) { + // portBLAS dot implementation requires that result is initialized to zero + // before performing the computation. + queue.submit([&](sycl::handler &cgh) { + auto result_acc = result.template get_access(cgh); + cgh.single_task([=]() { result_acc[0] = real_t(0); }); + }); + CALL_PORTBLAS_FN(::blas::_dot, queue, n, x, incx, y, incy, result); +} + +#ifdef ENABLE_MIXED_PRECISION_WITH_DOUBLE +void dot(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &y, std::int64_t incy, sycl::buffer &result) { + throw unimplemented("blas", "dot", " for unmatched return type"); +} +#endif + +void sdsdot(sycl::queue &queue, std::int64_t n, real_t sb, sycl::buffer &x, + std::int64_t incx, sycl::buffer &y, std::int64_t incy, + sycl::buffer &result) { + // portBLAS sdsdot implementation requires that result is initialized to zero + // before performing the computation. + queue.submit([&](sycl::handler &cgh) { + auto result_acc = result.template get_access(cgh); + cgh.single_task([=]() { result_acc[0] = real_t(0); }); + }); + CALL_PORTBLAS_FN(::blas::_sdsdot, queue, n, sb, x, incx, y, incy, result); +} + +void nrm2(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer &result) { + throw unimplemented("blas", "nrm2", " for complex"); +} + +void nrm2(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &result) { + // portBLAS nrm2 implementation requires that result is initialized to zero + // before performing the computation. + queue.submit([&](sycl::handler &cgh) { + auto result_acc = result.template get_access(cgh); + cgh.single_task([=]() { result_acc[0] = real_t(0); }); + }); + CALL_PORTBLAS_FN(::blas::_nrm2, queue, n, x, incx, result); +} + +void rot(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer, 1> &y, std::int64_t incy, real_t c, + real_t s) { + throw unimplemented("blas", "rot", " for complex"); +} + +void rot(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &y, std::int64_t incy, real_t c, real_t s) { + CALL_PORTBLAS_FN(::blas::_rot, queue, n, x, incx, y, incy, c, s); +} + +void rotg(sycl::queue &queue, sycl::buffer &a, sycl::buffer &b, + sycl::buffer &c, sycl::buffer &s) { + CALL_PORTBLAS_FN(::blas::_rotg, queue, a, b, c, s); +} + +void rotg(sycl::queue &queue, sycl::buffer, 1> &a, + sycl::buffer, 1> &b, sycl::buffer &c, + sycl::buffer, 1> &s) { + throw unimplemented("blas", "rotg", " for complex"); +} + +void rotm(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &y, std::int64_t incy, sycl::buffer ¶m) { + CALL_PORTBLAS_FN(::blas::_rotm, queue, n, x, incx, y, incy, param); +} + +void rotmg(sycl::queue &queue, sycl::buffer &d1, sycl::buffer &d2, + sycl::buffer &x1, real_t y1, sycl::buffer ¶m) { + sycl::buffer y1_buffer(&y1, sycl::range<1>{ 1 }); + CALL_PORTBLAS_FN(::blas::_rotmg, queue, d1, d2, x1, y1_buffer, param); +} + +void scal(sycl::queue &queue, std::int64_t n, real_t alpha, sycl::buffer &x, + std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_scal, queue, n, alpha, x, incx); +} + +void scal(sycl::queue &queue, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "scal", " for complex"); +} + +void scal(sycl::queue &queue, std::int64_t n, real_t alpha, + sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "scal", " for complex"); +} + +void swap(sycl::queue &queue, std::int64_t n, sycl::buffer &x, std::int64_t incx, + sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_swap, queue, n, x, incx, y, incy); +} + +void swap(sycl::queue &queue, std::int64_t n, sycl::buffer, 1> &x, + std::int64_t incx, sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "swap", " for complex"); +} + +// USM APIs + +sycl::event dotc(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *result, const std::vector &dependencies) { + throw unimplemented("blas", "dotc", " for USM"); +} + +sycl::event dotu(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, const std::complex *y, std::int64_t incy, + std::complex *result, const std::vector &dependencies) { + throw unimplemented("blas", "dotu", " for USM"); +} + +sycl::event iamax(sycl::queue &queue, std::int64_t n, const real_t *x, std::int64_t incx, + std::int64_t *result, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_iamax, queue, n, x, incx, result, dependencies); +} + +sycl::event iamax(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + throw unimplemented("blas", "iamax", " for USM"); +} + +sycl::event iamin(sycl::queue &queue, std::int64_t n, const real_t *x, std::int64_t incx, + std::int64_t *result, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_iamin, queue, n, x, incx, result, dependencies); +} + +sycl::event iamin(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, std::int64_t *result, + const std::vector &dependencies) { + throw unimplemented("blas", "iamin", " for USM"); +} + +sycl::event asum(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, real_t *result, const std::vector &dependencies) { + throw unimplemented("blas", "asum", " for USM"); +} + +sycl::event asum(sycl::queue &queue, std::int64_t n, const real_t *x, std::int64_t incx, + real_t *result, const std::vector &dependencies) { + // portBLAS asum implementation requires result to be initializes to zero + // before starting the computation. + auto init_res_val = queue.submit( + [&](sycl::handler &cgh) { cgh.single_task([=]() { result[0] = real_t(0); }); }); + std::vector new_dependencies = dependencies; + new_dependencies.push_back(init_res_val); + CALL_PORTBLAS_USM_FN(::blas::_asum, queue, n, x, incx, result, new_dependencies); +} + +sycl::event axpy(sycl::queue &queue, std::int64_t n, real_t alpha, const real_t *x, + std::int64_t incx, real_t *y, std::int64_t incy, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_axpy, queue, n, alpha, x, incx, y, incy, dependencies); +} + +sycl::event axpy(sycl::queue &queue, std::int64_t n, std::complex alpha, + const std::complex *x, std::int64_t incx, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + throw unimplemented("blas", "axpy", " for USM"); +} + +sycl::event axpby(sycl::queue &queue, std::int64_t n, real_t alpha, const real_t *x, + std::int64_t incx, const real_t beta, real_t *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "axpby", " for USM"); +} + +sycl::event axpby(sycl::queue &queue, std::int64_t n, std::complex alpha, + const std::complex *x, std::int64_t incx, const std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "axpby", " for USM"); +} + +sycl::event copy(sycl::queue &queue, std::int64_t n, const real_t *x, std::int64_t incx, real_t *y, + std::int64_t incy, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_copy, queue, n, x, incx, y, incy, dependencies); +} + +sycl::event copy(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "copy", " for USM"); +} + +sycl::event dot(sycl::queue &queue, std::int64_t n, const real_t *x, std::int64_t incx, + const real_t *y, std::int64_t incy, real_t *result, + const std::vector &dependencies) { + // portBLAS dot implementation requires result to be initializes to zero + // before starting the computation. + auto init_res_val = queue.submit( + [&](sycl::handler &cgh) { cgh.single_task([=]() { result[0] = real_t(0); }); }); + std::vector new_dependencies = dependencies; + new_dependencies.emplace_back(init_res_val); + CALL_PORTBLAS_USM_FN(::blas::_dot, queue, n, x, incx, y, incy, result, new_dependencies); +} + +#ifdef ENABLE_MIXED_PRECISION_WITH_DOUBLE +sycl::event dot(sycl::queue &queue, std::int64_t n, const float *x, std::int64_t incx, + const float *y, std::int64_t incy, double *result, + const std::vector &dependencies) { + throw unimplemented("blas", "dot", " for USM"); +} +#endif + +sycl::event sdsdot(sycl::queue &queue, std::int64_t n, real_t sb, const real_t *x, + std::int64_t incx, const real_t *y, std::int64_t incy, real_t *result, + const std::vector &dependencies) { + // portBLAS sdsdot implementation requires result to be initializes to zero + // before starting the computation. + auto init_res_val = queue.submit( + [&](sycl::handler &cgh) { cgh.single_task([=]() { result[0] = real_t(0); }); }); + std::vector new_dependencies = dependencies; + new_dependencies.emplace_back(init_res_val); + CALL_PORTBLAS_USM_FN(::blas::_sdsdot, queue, n, sb, x, incx, y, incy, result, new_dependencies); +} + +sycl::event nrm2(sycl::queue &queue, std::int64_t n, const std::complex *x, + std::int64_t incx, real_t *result, const std::vector &dependencies) { + throw unimplemented("blas", "nrm2", " for USM"); +} + +sycl::event nrm2(sycl::queue &queue, std::int64_t n, const real_t *x, std::int64_t incx, + real_t *result, const std::vector &dependencies) { + // portBLAS nrm2 implementation requires result to be initializes to zero + // before starting the computation. + auto init_res_val = queue.submit( + [&](sycl::handler &cgh) { cgh.single_task([=]() { result[0] = real_t(0); }); }); + std::vector new_dependencies = dependencies; + new_dependencies.push_back(init_res_val); + CALL_PORTBLAS_USM_FN(::blas::_nrm2, queue, n, x, incx, result, new_dependencies); +} + +sycl::event rot(sycl::queue &queue, std::int64_t n, std::complex *x, std::int64_t incx, + std::complex *y, std::int64_t incy, real_t c, real_t s, + const std::vector &dependencies) { + throw unimplemented("blas", "rot", " for USM"); +} + +sycl::event rot(sycl::queue &queue, std::int64_t n, real_t *x, std::int64_t incx, real_t *y, + std::int64_t incy, real_t c, real_t s, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_rot, queue, n, x, incx, y, incy, c, s, dependencies); +} + +sycl::event rotg(sycl::queue &queue, real_t *a, real_t *b, real_t *c, real_t *s, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_rotg, queue, a, b, c, s, dependencies); +} + +sycl::event rotg(sycl::queue &queue, std::complex *a, std::complex *b, real_t *c, + std::complex *s, const std::vector &dependencies) { + throw unimplemented("blas", "rotg", " for USM"); +} + +sycl::event rotm(sycl::queue &queue, std::int64_t n, real_t *x, std::int64_t incx, real_t *y, + std::int64_t incy, real_t *param, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_rotm, queue, n, x, incx, y, incy, param, dependencies); +} + +sycl::event rotmg(sycl::queue &queue, real_t *d1, real_t *d2, real_t *x1, real_t y1, real_t *param, + const std::vector &dependencies) { + auto y_d = + (real_t *)sycl::malloc_device(sizeof(real_t), queue.get_device(), queue.get_context()); + auto copy_in_event = queue.memcpy(y_d, &y1, sizeof(real_t), dependencies); + auto rotmg_event = std::invoke([&]() -> sycl::event { + CALL_PORTBLAS_USM_FN(::blas::_rotmg, queue, d1, d2, x1, y_d, param, + std::vector{ copy_in_event }); + }); + auto free_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(rotmg_event); + cgh.host_task([=]() { sycl::free(y_d, queue); }); + }); + return free_event; +} + +sycl::event scal(sycl::queue &queue, std::int64_t n, real_t alpha, real_t *x, std::int64_t incx, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_scal, queue, n, alpha, x, incx, dependencies); +} + +sycl::event scal(sycl::queue &queue, std::int64_t n, std::complex alpha, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + throw unimplemented("blas", "scal", " for USM"); +} + +sycl::event scal(sycl::queue &queue, std::int64_t n, real_t alpha, std::complex *x, + std::int64_t incx, const std::vector &dependencies) { + throw unimplemented("blas", "scal", " for USM"); +} + +sycl::event swap(sycl::queue &queue, std::int64_t n, real_t *x, std::int64_t incx, real_t *y, + std::int64_t incy, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_swap, queue, n, x, incx, y, incy, dependencies); +} + +sycl::event swap(sycl::queue &queue, std::int64_t n, std::complex *x, std::int64_t incx, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "swap", " for USM"); +} diff --git a/src/blas/backends/portblas/portblas_level1_double.cpp b/src/blas/backends/portblas/portblas_level1_double.cpp new file mode 100644 index 000000000..4c99f98c6 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level1_double.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +using real_t = double; +#define ENABLE_MIXED_PRECISION_WITH_DOUBLE + +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_level1.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_level1.cxx" +#undef ROW_MAJOR + +#undef ENABLE_MIXED_PRECISION_WITH_DOUBLE +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level1_float.cpp b/src/blas/backends/portblas/portblas_level1_float.cpp new file mode 100644 index 000000000..744729f1a --- /dev/null +++ b/src/blas/backends/portblas/portblas_level1_float.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +using real_t = float; + +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_level1.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_level1.cxx" +#undef ROW_MAJOR + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level2.cxx b/src/blas/backends/portblas/portblas_level2.cxx new file mode 100644 index 000000000..b3d8b6766 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level2.cxx @@ -0,0 +1,470 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// Buffer APIs + +void gemv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + real_t alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, + std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_gemv, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void gemv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "gemv", " for complex"); +} + +void gbmv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, real_t alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx, real_t beta, + sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_gbmv, queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, + incy); +} + +void gbmv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "gbmv", " for complex"); +} + +void ger(sycl::queue &queue, std::int64_t m, std::int64_t n, real_t alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a, std::int64_t lda) { + CALL_PORTBLAS_FN(::blas::_ger, queue, m, n, alpha, x, incx, y, incy, a, lda); +} + +void gerc(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + throw unimplemented("blas", "gerc", ""); +} + +void geru(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + throw unimplemented("blas", "geru", ""); +} + +void hbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "hbmv", ""); +} + +void hemv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "hemv", ""); +} + +void her(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &a, std::int64_t lda) { + throw unimplemented("blas", "her", ""); +} + +void her2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a, std::int64_t lda) { + throw unimplemented("blas", "her2", ""); +} + +void hpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx, std::complex beta, + sycl::buffer, 1> &y, std::int64_t incy) { + throw unimplemented("blas", "hpmv", ""); +} + +void hpr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &a) { + throw unimplemented("blas", "hpr", ""); +} + +void hpr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &x, std::int64_t incx, + sycl::buffer, 1> &y, std::int64_t incy, + sycl::buffer, 1> &a) { + throw unimplemented("blas", "hpr2", ""); +} + +void sbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k, + real_t alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &x, + std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_sbmv, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, + incy); +} + +void symv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &x, + std::int64_t incx, real_t beta, sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_symv, queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, incy); +} + +void syr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &a, + std::int64_t lda) { + CALL_PORTBLAS_FN(::blas::_syr, queue, upper_lower, n, alpha, x, incx, a, lda); +} + +void syr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a, std::int64_t lda) { + CALL_PORTBLAS_FN(::blas::_syr2, queue, upper_lower, n, alpha, x, incx, y, incy, a, lda); +} + +void spmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer &a, sycl::buffer &x, std::int64_t incx, real_t beta, + sycl::buffer &y, std::int64_t incy) { + CALL_PORTBLAS_FN(::blas::_spmv, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy); +} + +void spr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &a) { + CALL_PORTBLAS_FN(::blas::_spr, queue, upper_lower, n, alpha, x, incx, a); +} + +void spr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + sycl::buffer &x, std::int64_t incx, sycl::buffer &y, + std::int64_t incy, sycl::buffer &a) { + CALL_PORTBLAS_FN(::blas::_spr2, queue, upper_lower, n, alpha, x, incx, y, incy, a); +} + +void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_tbmv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +void tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "tbmv", ""); +} + +void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, sycl::buffer &a, + std::int64_t lda, sycl::buffer &x, std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_tbsv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx); +} + +void tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "tbsv", ""); +} + +void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, + sycl::buffer &x, std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_tpmv, queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +void tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "tpmv", ""); +} + +void tpsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, + sycl::buffer &x, std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_tpsv, queue, upper_lower, trans, unit_diag, n, a, x, incx); +} + +void tpsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "tpsv", ""); +} + +void trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_trmv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +void trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "trmv", " for complex"); +} + +void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer &a, std::int64_t lda, + sycl::buffer &x, std::int64_t incx) { + CALL_PORTBLAS_FN(::blas::_trsv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx); +} + +void trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &x, std::int64_t incx) { + throw unimplemented("blas", "trsv", ""); +} + +// USM APIs + +sycl::event gemv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + real_t alpha, const real_t *a, std::int64_t lda, const real_t *x, + std::int64_t incx, real_t beta, real_t *y, std::int64_t incy, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_gemv, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, incy, + dependencies); +} + +sycl::event gemv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "gemv", " for USM"); +} + +sycl::event gbmv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, real_t alpha, const real_t *a, std::int64_t lda, + const real_t *x, std::int64_t incx, real_t beta, real_t *y, std::int64_t incy, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_gbmv, queue, trans, m, n, kl, ku, alpha, a, lda, x, incx, beta, y, + incy, dependencies); +} + +sycl::event gbmv(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + std::int64_t kl, std::int64_t ku, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *x, + std::int64_t incx, std::complex beta, std::complex *y, + std::int64_t incy, const std::vector &dependencies) { + throw unimplemented("blas", "gbmv", " for USM"); +} + +sycl::event ger(sycl::queue &queue, std::int64_t m, std::int64_t n, real_t alpha, const real_t *x, + std::int64_t incx, const real_t *y, std::int64_t incy, real_t *a, std::int64_t lda, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_ger, queue, m, n, alpha, x, incx, y, incy, a, lda, dependencies); +} + +sycl::event gerc(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *x, std::int64_t incx, const std::complex *y, + std::int64_t incy, std::complex *a, std::int64_t lda, + const std::vector &dependencies) { + throw unimplemented("blas", "gerc", " for USM"); +} + +sycl::event geru(sycl::queue &queue, std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *x, std::int64_t incx, const std::complex *y, + std::int64_t incy, std::complex *a, std::int64_t lda, + const std::vector &dependencies) { + throw unimplemented("blas", "geru", " for USM"); +} + +sycl::event hbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k, + std::complex alpha, const std::complex *a, std::int64_t lda, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "hbmv", " for USM"); +} + +sycl::event hemv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "hemv", " for USM"); +} + +sycl::event her(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const std::complex *x, std::int64_t incx, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + throw unimplemented("blas", "her", " for USM"); +} + +sycl::event her2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + std::int64_t lda, const std::vector &dependencies) { + throw unimplemented("blas", "her2", " for USM"); +} + +sycl::event hpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *a, + const std::complex *x, std::int64_t incx, std::complex beta, + std::complex *y, std::int64_t incy, + const std::vector &dependencies) { + throw unimplemented("blas", "hpmv", " for USM"); +} + +sycl::event hpr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const std::complex *x, std::int64_t incx, std::complex *a, + const std::vector &dependencies) { + throw unimplemented("blas", "hpr", " for USM"); +} + +sycl::event hpr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, + std::complex alpha, const std::complex *x, std::int64_t incx, + const std::complex *y, std::int64_t incy, std::complex *a, + const std::vector &dependencies) { + throw unimplemented("blas", "hpr2", " for USM"); +} + +sycl::event sbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, std::int64_t k, + real_t alpha, const real_t *a, std::int64_t lda, const real_t *x, + std::int64_t incx, real_t beta, real_t *y, std::int64_t incy, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_sbmv, queue, upper_lower, n, k, alpha, a, lda, x, incx, beta, y, + incy, dependencies); +} + +sycl::event symv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const real_t *a, std::int64_t lda, const real_t *x, std::int64_t incx, real_t beta, + real_t *y, std::int64_t incy, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_symv, queue, upper_lower, n, alpha, a, lda, x, incx, beta, y, + incy, dependencies); +} + +sycl::event syr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const real_t *x, std::int64_t incx, real_t *a, std::int64_t lda, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_syr, queue, upper_lower, n, alpha, x, incx, a, lda, dependencies); +} + +sycl::event syr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const real_t *x, std::int64_t incx, const real_t *y, std::int64_t incy, real_t *a, + std::int64_t lda, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_syr2, queue, upper_lower, n, alpha, x, incx, y, incy, a, lda, + dependencies); +} + +sycl::event spmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const real_t *a, const real_t *x, std::int64_t incx, real_t beta, real_t *y, + std::int64_t incy, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_spmv, queue, upper_lower, n, alpha, a, x, incx, beta, y, incy, + dependencies); +} + +sycl::event spr(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const real_t *x, std::int64_t incx, real_t *a, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_spr, queue, upper_lower, n, alpha, x, incx, a, dependencies); +} + +sycl::event spr2(sycl::queue &queue, oneapi::mkl::uplo upper_lower, std::int64_t n, real_t alpha, + const real_t *x, std::int64_t incx, const real_t *y, std::int64_t incy, real_t *a, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_spr2, queue, upper_lower, n, alpha, x, incx, y, incy, a, + dependencies); +} + +sycl::event tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, const real_t *a, + std::int64_t lda, real_t *x, std::int64_t incx, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_tbmv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx, + dependencies); +} + +sycl::event tbmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, + const std::complex *a, std::int64_t lda, std::complex *x, + std::int64_t incx, const std::vector &dependencies) { + throw unimplemented("blas", "tbmv", " for USM"); +} + +sycl::event tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, const real_t *a, + std::int64_t lda, real_t *x, std::int64_t incx, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_tbsv, queue, upper_lower, trans, unit_diag, n, k, a, lda, x, incx, + dependencies); +} + +sycl::event tbsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, std::int64_t k, + const std::complex *a, std::int64_t lda, std::complex *x, + std::int64_t incx, const std::vector &dependencies) { + throw unimplemented("blas", "tbsv", " for USM"); +} + +sycl::event tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const real_t *a, real_t *x, + std::int64_t incx, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_tpmv, queue, upper_lower, trans, unit_diag, n, a, x, incx, + dependencies); +} + +sycl::event tpmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const std::complex *a, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + throw unimplemented("blas", "tpmv", " for USM"); +} + +sycl::event tpsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const real_t *a, real_t *x, + std::int64_t incx, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_tpsv, queue, upper_lower, trans, unit_diag, n, a, x, incx, + dependencies); +} + +sycl::event tpsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const std::complex *a, + std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + throw unimplemented("blas", "tpsv", " for USM"); +} + +sycl::event trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const real_t *a, std::int64_t lda, + real_t *x, std::int64_t incx, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_trmv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx, + dependencies); +} + +sycl::event trmv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const std::complex *a, + std::int64_t lda, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + throw unimplemented("blas", "trmv", " for USM"); +} + +sycl::event trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const real_t *a, std::int64_t lda, + real_t *x, std::int64_t incx, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_trsv, queue, upper_lower, trans, unit_diag, n, a, lda, x, incx, + dependencies); +} + +sycl::event trsv(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + oneapi::mkl::diag unit_diag, std::int64_t n, const std::complex *a, + std::int64_t lda, std::complex *x, std::int64_t incx, + const std::vector &dependencies) { + throw unimplemented("blas", "trsv", " for USM"); +} diff --git a/src/blas/backends/portblas/portblas_level2_double.cpp b/src/blas/backends/portblas/portblas_level2_double.cpp new file mode 100644 index 000000000..092aa0c59 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level2_double.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +using real_t = double; + +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_level2.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_level2.cxx" +#undef ROW_MAJOR + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level2_float.cpp b/src/blas/backends/portblas/portblas_level2_float.cpp new file mode 100644 index 000000000..7308c05da --- /dev/null +++ b/src/blas/backends/portblas/portblas_level2_float.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +using real_t = float; + +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_level2.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_level2.cxx" +#undef ROW_MAJOR + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level3.cxx b/src/blas/backends/portblas/portblas_level3.cxx new file mode 100644 index 000000000..4eeb1e8f1 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level3.cxx @@ -0,0 +1,451 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// Buffer APIs + +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, real_t alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, real_t beta, + sycl::buffer &c, std::int64_t ldc) { + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); +} + +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + using sycl_complex_real_t = sycl::ext::oneapi::experimental::complex; + if (transa == oneapi::mkl::transpose::conjtrans || + transb == oneapi::mkl::transpose::conjtrans) { + throw unimplemented("blas", "gemm", "Conjugate Transpose unsupported yet on portBLAS"); + } + // Intermediate buffers for conversion purposes as portBLAS expects sycl::complex instead of std::complex + sycl::buffer a_pb{ sycl::range<1>(a.size()) }; + sycl::buffer b_pb{ sycl::range<1>(b.size()) }; + sycl::buffer c_pb{ sycl::range<1>(c.size()) }; + + sycl::accessor, 1, sycl::access::mode::read> a_acc(a); + sycl::accessor a_pb_acc(a_pb); + queue.copy(a_acc, a_pb_acc); + + sycl::accessor, 1, sycl::access::mode::read> b_acc(b); + sycl::accessor b_pb_acc(b_pb); + queue.copy(b_acc, b_pb_acc); + + sycl::accessor, 1, sycl::access::mode::read> c_acc(c); + sycl::accessor c_pb_acc(c_pb); + queue.copy(c_acc, c_pb_acc); + + CALL_PORTBLAS_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a_pb, lda, b_pb, ldb, + beta, c_pb, ldc); + + // Copy c_pb back to c + sycl::accessor, 1, sycl::access::mode::write> out_acc(c); + sycl::accessor out_pb_acc(c_pb); + queue.copy(out_pb_acc, out_acc); +} + +void symm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + std::int64_t m, std::int64_t n, real_t alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, real_t beta, + sycl::buffer &c, std::int64_t ldc) { + CALL_PORTBLAS_FN(::blas::_symm, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc); +} + +void symm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "symm", ""); +} + +void hemm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "hemm", ""); +} + +void syrk(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, real_t alpha, sycl::buffer &a, + std::int64_t lda, real_t beta, sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "syrk", ""); +} + +void syrk(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "syrk", ""); +} + +void herk(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, real_t alpha, sycl::buffer, 1> &a, + std::int64_t lda, real_t beta, sycl::buffer, 1> &c, + std::int64_t ldc) { + throw unimplemented("blas", "herk", ""); +} + +void syr2k(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, real_t alpha, sycl::buffer &a, + std::int64_t lda, sycl::buffer &b, std::int64_t ldb, real_t beta, + sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "syr2k", ""); +} + +void syr2k(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "syr2k", ""); +} + +void her2k(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, real_t beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "her2k", ""); +} + +void trmm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, + real_t alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + throw unimplemented("blas", "trmm", ""); +} + +void trmm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb) { + throw unimplemented("blas", "trmm", ""); +} + +void trsm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, + real_t alpha, sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + CALL_PORTBLAS_FN(::blas::_trsm, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, + a, lda, b, ldb); +} + +void trsm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb) { + throw unimplemented("blas", "trsm", " for complex"); +} + +void gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, real_t alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, real_t beta, sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemmt", ""); +} + +void gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + sycl::buffer, 1> &b, std::int64_t ldb, std::complex beta, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "gemmt", ""); +} + +void omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, real_t alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb) { + CALL_PORTBLAS_FN(::blas::_omatcopy, queue, trans, m, n, alpha, a, lda, b, ldb); +} + +void omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb) { + throw unimplemented("blas", "omatcopy", ""); +} + +void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, real_t alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stridea, + sycl::buffer &b, std::int64_t ldb, std::int64_t strideb) { + CALL_PORTBLAS_FN(::blas::_omatcopy2, queue, trans, m, n, alpha, a, lda, stridea, b, ldb, + strideb); +} + +void omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +void imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, real_t alpha, + sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { + throw unimplemented("blas", "imatcopy", ""); +} + +void imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &ab, + std::int64_t lda, std::int64_t ldb) { + throw unimplemented("blas", "imatcopy", ""); +} + +void omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + real_t alpha, sycl::buffer &a, std::int64_t lda, real_t beta, + sycl::buffer &b, std::int64_t ldb, sycl::buffer &c, + std::int64_t ldc) { + CALL_PORTBLAS_FN(::blas::_omatadd, queue, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, + ldc); +} + +void omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, + std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, + std::complex beta, sycl::buffer, 1> &b, std::int64_t ldb, + sycl::buffer, 1> &c, std::int64_t ldc) { + throw unimplemented("blas", "omatadd", ""); +} + +// USM APIs + +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, real_t alpha, const real_t *a, + std::int64_t lda, const real_t *b, std::int64_t ldb, real_t beta, real_t *c, + std::int64_t ldc, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +} + +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + if (transa == oneapi::mkl::transpose::conjtrans || + transb == oneapi::mkl::transpose::conjtrans) { + throw unimplemented("blas", "gemm", "Conjugate Transpose unsupported yet on portBLAS"); + } + CALL_PORTBLAS_USM_FN(::blas::_gemm, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); +} + +sycl::event symm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + std::int64_t m, std::int64_t n, real_t alpha, const real_t *a, std::int64_t lda, + const real_t *b, std::int64_t ldb, real_t beta, real_t *c, std::int64_t ldc, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_symm, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); +} + +sycl::event symm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + throw unimplemented("blas", "symm", " for USM"); +} + +sycl::event hemm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + throw unimplemented("blas", "hemm", " for USM"); +} + +sycl::event syrk(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, real_t alpha, const real_t *a, std::int64_t lda, + real_t beta, real_t *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk", " for USM"); +} + +sycl::event syrk(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, std::complex beta, + std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "syrk", " for USM"); +} + +sycl::event herk(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, real_t alpha, const std::complex *a, + std::int64_t lda, real_t beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "herk", " for USM"); +} + +sycl::event syr2k(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, real_t alpha, const real_t *a, std::int64_t lda, + const real_t *b, std::int64_t ldb, real_t beta, real_t *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "syr2k", " for USM"); +} + +sycl::event syr2k(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, std::complex beta, std::complex *c, + std::int64_t ldc, const std::vector &dependencies) { + throw unimplemented("blas", "syr2k", " for USM"); +} + +sycl::event her2k(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, + std::int64_t n, std::int64_t k, std::complex alpha, + const std::complex *a, std::int64_t lda, const std::complex *b, + std::int64_t ldb, real_t beta, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "her2k", " for USM"); +} + +sycl::event trmm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, real_t alpha, const real_t *a, std::int64_t lda, real_t *b, + std::int64_t ldb, const std::vector &dependencies) { + throw unimplemented("blas", "trmm", " for USM"); +} + +sycl::event trmm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + throw unimplemented("blas", "trmm", " for USM"); +} + +sycl::event trsm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, real_t alpha, const real_t *a, std::int64_t lda, real_t *b, + std::int64_t ldb, const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_trsm, queue, left_right, upper_lower, trans, unit_diag, m, n, + alpha, a, lda, b, ldb, dependencies); +} + +sycl::event trsm(sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, + oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + throw unimplemented("blas", "trsm", " for USM"); +} + +sycl::event gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, real_t alpha, + const real_t *a, std::int64_t lda, const real_t *b, std::int64_t ldb, real_t beta, + real_t *c, std::int64_t ldc, const std::vector &dependencies) { + throw unimplemented("blas", "gemmt", " for USM"); +} + +sycl::event gemmt(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, + oneapi::mkl::transpose transb, std::int64_t n, std::int64_t k, + std::complex alpha, const std::complex *a, std::int64_t lda, + const std::complex *b, std::int64_t ldb, std::complex beta, + std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "gemmt", " for USM"); +} + +sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + real_t alpha, const real_t *a, std::int64_t lda, real_t *b, std::int64_t ldb, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatcopy, queue, trans, m, n, alpha, a, lda, b, ldb, + dependencies); +} + +sycl::event omatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::complex *b, std::int64_t ldb, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy", "for USM"); +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + real_t alpha, const real_t *a, std::int64_t lda, std::int64_t stridea, + real_t *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatcopy2, queue, trans, m, n, alpha, a, lda, stridea, b, ldb, + strideb, dependencies); +} + +sycl::event omatcopy2(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, std::int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", "for USM"); +} + +sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + real_t alpha, real_t *ab, std::int64_t lda, std::int64_t ldb, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy", ""); +} + +sycl::event imatcopy(sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, + std::complex alpha, std::complex *ab, std::int64_t lda, + std::int64_t ldb, const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy", ""); +} + +sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, real_t alpha, const real_t *a, std::int64_t lda, real_t beta, + const real_t *b, std::int64_t ldb, real_t *c, std::int64_t ldc, + const std::vector &dependencies) { + CALL_PORTBLAS_USM_FN(::blas::_omatadd, queue, transa, transb, m, n, alpha, a, lda, beta, b, ldb, + c, ldc, dependencies); +} + +sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, + std::int64_t n, std::complex alpha, const std::complex *a, + std::int64_t lda, std::complex beta, const std::complex *b, + std::int64_t ldb, std::complex *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "omatadd", ""); +} +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + real_t *alpha, const real_t **a, int64_t *lda, real_t **b, int64_t *ldb, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", ""); +} + +sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, const std::complex **a, + int64_t *lda, std::complex **b, int64_t *ldb, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy_batch", ""); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + real_t *alpha, real_t **ab, int64_t *lda, int64_t *ldb, + int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", ""); +} + +sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, + std::complex *alpha, std::complex **ab, int64_t *lda, + int64_t *ldb, int64_t group_count, int64_t *groupsize, + const std::vector &dependencies) { + throw unimplemented("blas", "imatcopy_batch", ""); +} diff --git a/src/blas/backends/portblas/portblas_level3_bfloat16.cpp b/src/blas/backends/portblas/portblas_level3_bfloat16.cpp new file mode 100644 index 000000000..1684b1b3e --- /dev/null +++ b/src/blas/backends/portblas/portblas_level3_bfloat16.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { +namespace column_major { + +// BUFFER +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemm", " for bfloat16"); +} + +// USM +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const oneapi::mkl::bfloat16 *a, std::int64_t lda, const oneapi::mkl::bfloat16 *b, + std::int64_t ldb, float beta, float *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm", " for USM"); +} + +} // namespace column_major +namespace row_major { + +// BUFFER +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, + sycl::buffer &b, std::int64_t ldb, float beta, + sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemm", " for bfloat16"); +} + +// USM +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + const oneapi::mkl::bfloat16 *a, std::int64_t lda, const oneapi::mkl::bfloat16 *b, + std::int64_t ldb, float beta, float *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm", " for USM"); +} + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level3_double.cpp b/src/blas/backends/portblas/portblas_level3_double.cpp new file mode 100644 index 000000000..9f9d82d37 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level3_double.cpp @@ -0,0 +1,60 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +using real_t = double; + +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_level3.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_level3.cxx" +#undef ROW_MAJOR + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level3_float.cpp b/src/blas/backends/portblas/portblas_level3_float.cpp new file mode 100644 index 000000000..53a5a1697 --- /dev/null +++ b/src/blas/backends/portblas/portblas_level3_float.cpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "portblas_common.hpp" +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { + +using real_t = float; + +namespace column_major { + +#define COLUMN_MAJOR +constexpr bool is_column_major() { + return true; +} +#include "portblas_level3.cxx" +#include "portblas_gemm_bias.cxx" +#undef COLUMN_MAJOR + +} // namespace column_major +namespace row_major { + +#define ROW_MAJOR +constexpr bool is_column_major() { + return false; +} +#include "portblas_level3.cxx" +#include "portblas_gemm_bias.cxx" +#undef ROW_MAJOR + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_level3_half.cpp b/src/blas/backends/portblas/portblas_level3_half.cpp new file mode 100644 index 000000000..0e42528fa --- /dev/null +++ b/src/blas/backends/portblas/portblas_level3_half.cpp @@ -0,0 +1,103 @@ +/******************************************************************************* +* Copyright Codeplay Software +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +namespace oneapi { +namespace mkl { +namespace blas { +namespace portblas { +namespace column_major { + +// BUFFER +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemm", " half"); +} + +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemm", " for different argument data types"); +} + +// USM +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, + sycl::half beta, sycl::half *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm", " for USM"); +} + +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, + std::int64_t ldc, const std::vector &dependencies) { + throw unimplemented("blas", "gemm", " for USM"); +} +} // namespace column_major +namespace row_major { + +// BUFFER +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, sycl::half beta, sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemm", " half"); +} + +void gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, sycl::buffer &b, + std::int64_t ldb, float beta, sycl::buffer &c, std::int64_t ldc) { + throw unimplemented("blas", "gemm", " for different argument data types"); +} + +// USM +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, sycl::half alpha, + const sycl::half *a, std::int64_t lda, const sycl::half *b, std::int64_t ldb, + sycl::half beta, sycl::half *c, std::int64_t ldc, + const std::vector &dependencies) { + throw unimplemented("blas", "gemm", " for USM"); +} + +sycl::event gemm(sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, const sycl::half *b, std::int64_t ldb, float beta, float *c, + std::int64_t ldc, const std::vector &dependencies) { + throw unimplemented("blas", "gemm", " for USM"); +} + +} // namespace row_major +} // namespace portblas +} // namespace blas +} // namespace mkl +} // namespace oneapi diff --git a/src/blas/backends/portblas/portblas_wrappers.cpp b/src/blas/backends/portblas/portblas_wrappers.cpp new file mode 100644 index 000000000..3f6170bb7 --- /dev/null +++ b/src/blas/backends/portblas/portblas_wrappers.cpp @@ -0,0 +1,21 @@ +// +// generated file +// + +#include "blas/function_table.hpp" + +#include "oneapi/mkl/blas/detail/portblas/onemkl_blas_portblas.hpp" + +#define WRAPPER_VERSION 1 + +extern "C" ONEMKL_EXPORT blas_function_table_t mkl_blas_table = { + WRAPPER_VERSION, +#define BACKEND portblas +#define MAJOR column_major +#include "../backend_wrappers.cxx" +#undef MAJOR +#define MAJOR row_major +#include "../backend_wrappers.cxx" +#undef MAJOR +#undef BACKEND +}; diff --git a/src/blas/backends/rocblas/CMakeLists.txt b/src/blas/backends/rocblas/CMakeLists.txt index 10b490db9..76dc126ad 100644 --- a/src/blas/backends/rocblas/CMakeLists.txt +++ b/src/blas/backends/rocblas/CMakeLists.txt @@ -21,7 +21,10 @@ set(LIB_NAME onemkl_blas_rocblas) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(rocBLAS REQUIRED) +find_package(hip REQUIRED) +find_package(rocblas REQUIRED) +find_package(Threads REQUIRED) + set(SOURCES rocblas_level1.cpp rocblas_level2.cpp rocblas_level3.cpp @@ -32,14 +35,32 @@ set(SOURCES rocblas_level1.cpp $<$: rocblas_wrappers.cpp>) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_blas ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/src ${PROJECT_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::rocBLAS::rocBLAS) + +if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=amdgcn-amd-amdhsa -fsycl-unnamed-lambda + -Xsycl-target-backend --offload-arch=${HIP_TARGETS}) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE + -fsycl-targets=amdgcn-amd-amdhsa -Xsycl-target-backend + --offload-arch=${HIP_TARGETS}) +else() + target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE) +endif() + +target_link_libraries(${LIB_OBJ} PRIVATE roc::rocblas hip::host Threads::Threads) +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL) target_compile_features(${LIB_OBJ} PUBLIC cxx_std_17) set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/src/blas/backends/rocblas/rocblas_batch.cpp b/src/blas/backends/rocblas/rocblas_batch.cpp index a230ee473..5fa103055 100644 --- a/src/blas/backends/rocblas/rocblas_batch.cpp +++ b/src/blas/backends/rocblas/rocblas_batch.cpp @@ -18,12 +18,51 @@ * limitations under the License. * **************************************************************************/ + #include "rocblas_helper.hpp" #include "rocblas_task.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hpp" +// Helper Functions + +template +static inline void conj_vector(sycl::handler &cgh, sycl::buffer &buf, const int64_t len, + const int64_t inc, const int64_t stride, const int64_t batch_size) { + const auto abs_inc = std::abs(inc); + const auto abs_stride = std::abs(stride); + auto acc = buf.template get_access(cgh); + cgh.parallel_for(sycl::range{ (std::size_t)batch_size, (std::size_t)len }, + [=](sycl::item<2> it) { + const auto index = it.get_id(0) * abs_stride + it.get_id(1) * abs_inc; + acc[index] = std::conj(acc[index]); + }); +} +template +static inline void conj_vector(sycl::handler &cgh, T *ptr, const int64_t len, const int64_t inc, + const int64_t stride, const int64_t batch_size) { + const auto abs_inc = std::abs(inc); + const auto abs_stride = std::abs(stride); + cgh.parallel_for(sycl::range{ (std::size_t)batch_size, (std::size_t)len }, + [=](sycl::item<2> it) { + const auto index = it.get_id(0) * abs_stride + it.get_id(1) * abs_inc; + ptr[index] = std::conj(ptr[index]); + }); +} + +template +static inline void conj_vector(sycl::handler &cgh, T **ptr, const int64_t len, const int64_t inc, + const int64_t stride, const int64_t group_size) { + const auto abs_inc = std::abs(inc); + cgh.parallel_for(sycl::range{ (std::size_t)group_size, (std::size_t)len }, + [=](sycl::item<2> it) { + const auto col = it.get_id(0) + stride; + const auto row = it.get_id(1) * abs_inc; + ptr[col][row] = std::conj(ptr[col][row]); + }); +} + namespace oneapi { namespace mkl { namespace blas { @@ -32,245 +71,365 @@ namespace column_major { // Buffer APIs -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} +template +inline void copy_batch(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, + int64_t stridey, int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, incx, incy, stridex, stridey, batch_size); -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} + queue.submit([&](sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + auto x_ = sc.get_mem(x_acc); + auto y_ = sc.get_mem(y_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, stridex, y_, incy, stridey, + batch_size); + }); + }); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} +#define COPY_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, \ + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, \ + int64_t batch_size) { \ + copy_batch(ROCBLAS_ROUTINE, queue, n, x, incx, stridex, y, incy, stridey, batch_size); \ + } -void axpy_batch(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +COPY_STRIDED_BATCH_LAUNCHER(float, rocblas_scopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER(double, rocblas_dcopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_ccopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zcopy_strided_batched) -void axpy_batch(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, - int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +#undef COPY_STRIDED_BATCH_LAUNCHER -void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +template +inline void axpy_batch(Func func, sycl::queue &queue, int64_t n, T alpha, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, + int64_t stridey, int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, incx, incy, stridex, stridey, batch_size); -void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} + queue.submit([&](sycl::handler &cgh) { + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, float beta, sycl::buffer &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + auto x_ = sc.get_mem(x_acc); + auto y_ = sc.get_mem(y_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, (rocDataType *)&alpha, x_, incx, stridex, + y_, incy, stridey, batch_size); + }); + }); } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, double beta, - sycl::buffer &y, int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#define AXPY_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void axpy_batch(sycl::queue &queue, int64_t n, TYPE alpha, sycl::buffer &x, \ + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, \ + int64_t stridey, int64_t batch_size) { \ + axpy_batch(ROCBLAS_ROUTINE, queue, n, alpha, x, incx, stridex, y, incy, stridey, \ + batch_size); \ + } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, sycl::buffer, 1> &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +AXPY_STRIDED_BATCH_LAUNCHER(float, rocblas_saxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER(double, rocblas_daxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_caxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zaxpy_strided_batched) -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, int64_t stride_y, - int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#undef AXPY_BATCH_LAUNCHER -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} +template +inline void gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + T alpha, sycl::buffer &a, int64_t lda, int64_t stridea, + sycl::buffer &x, int64_t incx, int64_t stridex, T beta, + sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, incx, incy, stridea, stridex, stridey, batch_size); -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto x_acc = x.template get_access(cgh); + auto y_acc = y.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + auto a_ = sc.get_mem(a_acc); + auto x_ = sc.get_mem(x_acc); + auto y_ = sc.get_mem(y_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), m, n, + (rocDataType *)&alpha, a_, lda, stridea, x_, incx, stridex, + (rocDataType *)&beta, y_, incy, stridey, batch_size); + }); + }); } -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} +#define GEMV_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void gemv_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &x, int64_t incx, int64_t stridex, TYPE beta, \ + sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { \ + gemv_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, x, incx, stridex, \ + beta, y, incy, stridey, batch_size); \ + } + +GEMV_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemv_strided_batched) + +#undef GEMV_STRIDED_BATCH_LAUNCHER template -inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, - T beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { +inline void dgmm_batch(Func func, sycl::queue &queue, side left_right, int64_t m, int64_t n, + sycl::buffer &a, int64_t lda, int64_t stridea, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &c, int64_t ldc, + int64_t stridec, int64_t batch_size) { using rocDataType = typename RocEquivalentType::Type; - overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + overflow_check(m, n, lda, ldc, incx, stridea, stridex, stridec, batch_size); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); - auto b_acc = b.template get_access(cgh); + auto x_acc = x.template get_access(cgh); auto c_acc = c.template get_access(cgh); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - auto a_ = sc.get_mem(a_acc); - auto b_ = sc.get_mem(b_acc); + auto a_ = sc.get_mem(a_acc); + auto x_ = sc.get_mem(x_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (rocDataType *)&beta, c_, - ldc, stride_c, batch_size); + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), m, n, a_, + lda, stridea, x_, incx, stridex, c_, ldc, stridec, batch_size); + }); + }); +} + +#define DGMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &x, int64_t incx, int64_t stridex, \ + sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ + dgmm_batch(ROCBLAS_ROUTINE, queue, left_right, m, n, a, lda, stridea, x, incx, stridex, c, \ + ldc, stridec, batch_size); \ + } + +DGMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER(double, rocblas_ddgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zdgmm_strided_batched) + +#undef DGMM_STRIDED_BATCH_LAUNCHER + +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle, + get_rocblas_operation(transa), get_rocblas_operation(transb), m, + n, k, &alpha, a_, get_rocblas_datatype(), lda, + stridea, b_, get_rocblas_datatype(), ldb, strideb, + &beta, c_, get_rocblas_datatype(), ldc, stridec, c_, + get_rocblas_datatype(), ldc, stridec, batch_size, + get_rocblas_datatype(), rocblas_gemm_algo_standard, + solution_index, flags); }); }); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ int64_t batch_size) { \ - gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, \ - ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, \ + beta, c, ldc, stridec, batch_size); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); -} +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); -} +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); -} +#undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); -} +template +inline void trsm_batch(Func func, sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, + sycl::buffer &a, int64_t lda, int64_t stridea, sycl::buffer &b, + int64_t ldb, int64_t strideb, int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, stridea, strideb, batch_size); -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, float beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), + get_rocblas_fill_mode(upper_lower), + get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag), + m, n, (rocDataType *)&alpha, a_, lda, stridea, b_, ldb, strideb, + batch_size); + }); + }); } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} +#define TRSM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \ + diag unit_diag, int64_t m, int64_t n, TYPE alpha, sycl::buffer &a, \ + int64_t lda, int64_t stridea, sycl::buffer &b, int64_t ldb, \ + int64_t strideb, int64_t batch_size) { \ + trsm_batch(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, \ + a, lda, stridea, b, ldb, strideb, batch_size); \ + } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} +TRSM_STRIDED_BATCH_LAUNCHER(float, rocblas_strsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER(double, rocblas_dtrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_ctrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_ztrsm_strided_batched) -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} +#undef TRSM_STRIDED_BATCH_LAUNCHER + +template +inline void syrk_batch(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, + int64_t k, T alpha, sycl::buffer &a, int64_t lda, int64_t stridea, + T beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, k, lda, ldc, stridea, stridec, batch_size); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + auto a_ = sc.get_mem(a_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_fill_mode(upper_lower), + get_rocblas_operation(trans), n, k, (rocDataType *)&alpha, a_, + lda, stridea, (rocDataType *)&beta, c_, ldc, stridec, + batch_size); + }); + }); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#define SYRK_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, \ + TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, TYPE beta, \ + sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ + syrk_batch(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, stridea, beta, \ + c, ldc, stridec, batch_size); \ + } + +SYRK_STRIDED_BATCH_LAUNCHER(float, rocblas_ssyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER(double, rocblas_dsyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_csyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zsyrk_strided_batched) + +#undef SYRK_STRIDED_BATCH_LAUNCHER + +template +inline void omatcopy_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + const T alpha, sycl::buffer &a, int64_t lda, int64_t stridea, + sycl::buffer &b, int64_t ldb, int64_t strideb, + int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, stridea, strideb, batch_size); + + const T beta = 0; + const int64_t new_m = trans == oneapi::mkl::transpose::nontrans ? m : n; + const int64_t new_n = trans == oneapi::mkl::transpose::nontrans ? n : m; + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), new_m, new_n, + (rocDataType *)&alpha, a_, lda, stridea, (rocDataType *)&beta, + nullptr, lda, stridea, b_, ldb, strideb, batch_size); + }); + }); } +#define OMATCOPY_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, \ + const TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &b, int64_t ldb, int64_t strideb, \ + int64_t batch_size) { \ + omatcopy_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, b, ldb, \ + strideb, batch_size); \ + } + +OMATCOPY_STRIDED_BATCH_LAUNCHER(float, rocblas_sgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER(double, rocblas_dgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgeam_strided_batched) + +#undef OMATCOPY_STRIDED_BATCH_LAUNCHER + void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size) { @@ -295,392 +454,604 @@ void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} +template +inline void omatadd_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, const T alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, const T beta, sycl::buffer &b, int64_t ldb, + int64_t strideb, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc, stridea, strideb, stridec, batch_size); -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, + lda, stridea, (rocDataType *)&beta, b_, ldb, strideb, c_, ldc, + stridec, batch_size); + }); + }); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} +#define OMATADD_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, const TYPE alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, const TYPE beta, sycl::buffer &b, int64_t ldb, \ + int64_t strideb, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + omatadd_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, stridea, beta, \ + b, ldb, strideb, c, ldc, stridec, batch_size); \ + } + +OMATADD_STRIDED_BATCH_LAUNCHER(float, rocblas_sgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER(double, rocblas_dgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgeam_strided_batched) + +#undef OMATADD_STRIDED_BATCH_LAUNCHER // USM APIs -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, int64_t *incx, double **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} +template +inline sycl::event copy_batch(Func func, sycl::queue &queue, int64_t *n, const T **x, int64_t *incx, + T **y, int64_t *incy, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(n[i], incx[i], incy[i], group_size[i]); + } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, - int64_t *incx, std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} + int64_t offset = 0; + rocblas_status err; + for (int64_t i = 0; i < group_count; i++) { + auto **x_ = reinterpret_cast(x); + auto **y_ = reinterpret_cast(y); + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, (int)n[i], x_ + offset, (int)incx[i], + y_ + offset, (int)incy[i], (int)group_size[i]); + offset += group_size[i]; + } + }); + }); -sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, int64_t stridex, - float *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); + return done; } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - int64_t stridex, double *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} +#define COPY_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event copy_batch(sycl::queue &queue, int64_t *n, const TYPE **x, int64_t *incx, \ + TYPE **y, int64_t *incy, int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return copy_batch(ROCBLAS_ROUTINE, queue, n, x, incx, y, incy, group_count, group_size, \ + dependencies); \ + } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} +COPY_BATCH_LAUNCHER_USM(float, rocblas_scopy_batched) +COPY_BATCH_LAUNCHER_USM(double, rocblas_dcopy_batched) +COPY_BATCH_LAUNCHER_USM(std::complex, rocblas_ccopy_batched) +COPY_BATCH_LAUNCHER_USM(std::complex, rocblas_zcopy_batched) -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for column_major layout"); -} +#undef COPY_BATCH_LAUNCHER_USM -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, const float **x, int64_t *incx, - float **y, int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +template +inline sycl::event copy_batch(Func func, sycl::queue &queue, int64_t n, const T *x, int64_t incx, + int64_t stridex, T *y, int64_t incy, int64_t stridey, + int64_t batch_size, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, incx, incy, stridex, stridey, batch_size); -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, const double **x, - int64_t *incx, double **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} + auto x_ = reinterpret_cast(x); + auto y_ = reinterpret_cast(y); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, stridex, y_, incy, stridey, + batch_size); + }); + }); -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); + return done; } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, - int64_t stridex, float *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +#define COPY_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event copy_batch(sycl::queue &queue, int64_t n, const TYPE *x, int64_t incx, \ + int64_t stridex, TYPE *y, int64_t incy, int64_t stridey, \ + int64_t batch_size, const std::vector &dependencies) { \ + return copy_batch(ROCBLAS_ROUTINE, queue, n, x, incx, stridex, y, incy, stridey, \ + batch_size, dependencies); \ + } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, - int64_t stridex, double *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +COPY_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_scopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dcopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_ccopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zcopy_strided_batched) -sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +#undef COPY_STRIDED_BATCH_LAUNCHER_USM -sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for column_major layout"); -} +template +inline sycl::event axpy_batch(Func func, sycl::queue &queue, int64_t *n, T *alpha, const T **x, + int64_t *incx, T **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(n[i], incx[i], incy[i], group_size[i]); + } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float beta, float *y, int64_t incy, int64_t stride_y, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, const double *x, - int64_t incx, int64_t stride_x, double beta, double *y, int64_t incy, - int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} + int64_t offset = 0; + rocblas_status err; + for (int64_t i = 0; i < group_count; i++) { + auto **x_ = reinterpret_cast(x); + auto **y_ = reinterpret_cast(y); + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, (int)n[i], (rocDataType *)&alpha[i], + x_ + offset, (int)incx[i], y_ + offset, (int)incy[i], + (int)group_size[i]); + offset += group_size[i]; + } + }); + }); -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, - int64_t incy, int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); + return done; } -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, - int64_t incy, int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#define AXPY_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event axpy_batch(sycl::queue &queue, int64_t *n, TYPE *alpha, const TYPE **x, \ + int64_t *incx, TYPE **y, int64_t *incy, int64_t group_count, \ + int64_t *group_size, const std::vector &dependencies) { \ + return axpy_batch(ROCBLAS_ROUTINE, queue, n, alpha, x, incx, y, incy, group_count, \ + group_size, dependencies); \ + } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, - const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, - float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +AXPY_BATCH_LAUNCHER_USM(float, rocblas_saxpy_batched) +AXPY_BATCH_LAUNCHER_USM(double, rocblas_daxpy_batched) +AXPY_BATCH_LAUNCHER_USM(std::complex, rocblas_caxpy_batched) +AXPY_BATCH_LAUNCHER_USM(std::complex, rocblas_zaxpy_batched) -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, - const double **a, int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +#undef AXPY_BATCH_LAUNCHER_USM -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} +template +inline sycl::event axpy_batch(Func func, sycl::queue &queue, int64_t n, T alpha, const T *x, + int64_t incx, int64_t stridex, T *y, int64_t incy, int64_t stridey, + int64_t batch_size, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, incx, incy, stridex, stridey, batch_size); -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for column_major layout"); -} + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, - int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} + auto x_ = reinterpret_cast(x); + auto y_ = reinterpret_cast(y); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, (rocDataType *)&alpha, x_, incx, stridex, + y_, incy, stridey, batch_size); + }); + }); -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const double *a, - int64_t lda, int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + return done; } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} +#define AXPY_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event axpy_batch(sycl::queue &queue, int64_t n, TYPE alpha, const TYPE *x, int64_t incx, \ + int64_t stridex, TYPE *y, int64_t incy, int64_t stridey, \ + int64_t batch_size, const std::vector &dependencies) { \ + return axpy_batch(ROCBLAS_ROUTINE, queue, n, alpha, x, incx, stridex, y, incy, stridey, \ + batch_size, dependencies); \ + } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} +AXPY_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_saxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_daxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_caxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zaxpy_strided_batched) -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const float **a, int64_t *lda, const float **x, int64_t *incx, float **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} +#undef AXPY_STRIDED_BATCH_LAUNCHER_USM -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const double **a, int64_t *lda, const double **x, int64_t *incx, double **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); -} +template +inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + T alpha, const T *a, int64_t lda, int64_t stridea, const T *x, + int64_t incx, int64_t stridex, T beta, T *y, int64_t incy, + int64_t stridey, int64_t batch_size, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, m, lda, incx, incy, stridea, stridex, stridey, batch_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); + auto a_ = reinterpret_cast(a); + auto x_ = reinterpret_cast(x); + auto y_ = reinterpret_cast(y); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), m, n, + (rocDataType *)&alpha, a_, lda, stridea, x_, incx, stridex, + (rocDataType *)&beta, y_, incy, stridey, batch_size); + }); + }); + + return done; } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for column_major layout"); +#define GEMV_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event gemv_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, int64_t stridea, const TYPE *x, \ + int64_t incx, int64_t stridex, TYPE beta, TYPE *y, int64_t incy, \ + int64_t stridey, int64_t batch_size, \ + const std::vector &dependencies) { \ + return gemv_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, x, incx, \ + stridex, beta, y, incy, stridey, batch_size, dependencies); \ + } + +GEMV_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemv_strided_batched) + +#undef GEMV_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, T *alpha, const T **a, int64_t *lda, const T **x, + int64_t *incx, T *beta, T **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], lda[i], incx[i], incy[i], group_size[i]); + } + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + int64_t offset = 0; + rocblas_status err; + for (int64_t i = 0; i < group_count; i++) { + auto **a_ = reinterpret_cast(a); + auto **x_ = reinterpret_cast(x); + auto **y_ = reinterpret_cast(y); + ROCBLAS_ERROR_FUNC_SYNC( + func, err, handle, get_rocblas_operation(trans[i]), (int)m[i], (int)n[i], + (rocDataType *)&alpha[i], a_ + offset, (int)lda[i], x_ + offset, (int)incx[i], + (rocDataType *)&beta[i], y_ + offset, (int)incy[i], (int)group_size[i]); + offset += group_size[i]; + } + }); + }); + + return done; } +#define GEMV_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event gemv_batch( \ + sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ + int64_t *lda, const TYPE **x, int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy, \ + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { \ + return gemv_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, \ + incy, group_count, group_size, dependencies); \ + } + +GEMV_BATCH_LAUNCHER_USM(float, rocblas_sgemv_batched) +GEMV_BATCH_LAUNCHER_USM(double, rocblas_dgemv_batched) +GEMV_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemv_batched) +GEMV_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemv_batched) + +#undef GEMV_BATCH_LAUNCHER_USM + template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, - int64_t stride_a, const T *b, int64_t ldb, int64_t stride_b, T beta, - T *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { +inline sycl::event dgmm_batch(Func func, sycl::queue &queue, side left_right, int64_t m, int64_t n, + const T *a, int64_t lda, int64_t stridea, const T *x, int64_t incx, + int64_t stridex, T *c, int64_t ldc, int64_t stridec, + int64_t batch_size, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; - overflow_check(m, n, k, lda, ldb, ldc, stride_a, stride_b, stride_c, batch_size); + overflow_check(m, n, incx, stridea, stridex, stridec, batch_size); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); - auto b_ = reinterpret_cast(b); + auto x_ = reinterpret_cast(x); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType *)&alpha, - a_, lda, stride_a, b_, ldb, stride_b, (rocDataType *)&beta, c_, - ldc, stride_c, batch_size); + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), m, n, a_, + lda, stridea, x_, incx, stridex, c_, ldc, stridec, batch_size); }); }); + return done; } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ +#define DGMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, \ + const TYPE *a, int64_t lda, int64_t stridea, const TYPE *x, \ + int64_t incx, int64_t stridex, TYPE *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return dgmm_batch(ROCBLAS_ROUTINE, queue, left_right, m, n, a, lda, stridea, x, incx, \ + stridex, c, ldc, stridec, batch_size, dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_ddgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_strided_batched) -#undef GEMM_STRIDED_BATCH_LAUNCHER_USM +#undef DGMM_STRIDED_BATCH_LAUNCHER_USM template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, T *alpha, const T **a, - int64_t *lda, const T **b, int64_t *ldb, T *beta, T **c, int64_t *ldc, - int64_t group_count, int64_t *group_size, +inline sycl::event dgmm_batch(Func func, sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const T **a, int64_t *lda, const T **x, int64_t *incx, + T **c, int64_t *ldc, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; for (int64_t i = 0; i < group_count; i++) { - overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); + overflow_check(m[i], n[i], lda[i], ldc[i], incx[i], group_size[i]); } + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); - int64_t offset = 0; rocblas_status err; + for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); - auto **b_ = reinterpret_cast(b); + auto **x_ = reinterpret_cast(x); auto **c_ = reinterpret_cast(c); - ROCBLAS_ERROR_FUNC_SYNC( - func, err, handle, get_rocblas_operation(transa[i]), - get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], - (rocDataType *)&alpha[i], a_ + offset, (int)lda[i], b_ + offset, (int)ldb[i], - (rocDataType *)&beta[i], c_ + offset, (int)ldc[i], (int)group_size[i]); + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right[i]), + (int)m[i], (int)n[i], a_ + offset, (int)lda[i], x_ + offset, + (int)incx[i], c_ + offset, (int)ldc[i], (int)group_size[i]); offset += group_size[i]; } }); }); + return done; } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ +#define DGMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, \ + const TYPE **a, int64_t *lda, const TYPE **x, int64_t *incx, TYPE **c, \ + int64_t *ldc, int64_t group_count, int64_t *group_size, \ const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ - beta, c, ldc, group_count, group_size, dependencies); \ + return dgmm_batch(ROCBLAS_ROUTINE, queue, left_right, m, n, a, lda, x, incx, c, ldc, \ + group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_batched) -GEMM_BATCH_LAUNCHER_USM(float, rocblas_sgemm_batched) -GEMM_BATCH_LAUNCHER_USM(double, rocblas_dgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_batched) +DGMM_BATCH_LAUNCHER_USM(float, rocblas_sdgmm_batched) +DGMM_BATCH_LAUNCHER_USM(double, rocblas_ddgmm_batched) +DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cdgmm_batched) +DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_batched) + +#undef DGMM_BATCH_LAUNCHER + +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stridea, + const Tb *b, int64_t ldb, int64_t strideb, Ts beta, + Tc *c, int64_t ldc, int64_t stridec, + int64_t batch_size, + const std::vector &dependencies) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; + overflow_check(m, n, k, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -#undef GEMM_BATCH_LAUNCHER_USM + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(rocblas_gemm_strided_batched_ex, err, handle, + get_rocblas_operation(transa), get_rocblas_operation(transb), m, + n, k, &alpha, a_, get_rocblas_datatype(), lda, + stridea, b_, get_rocblas_datatype(), ldb, strideb, + &beta, c_, get_rocblas_datatype(), ldc, stridec, c_, + get_rocblas_datatype(), ldc, stridec, batch_size, + get_rocblas_datatype(), rocblas_gemm_algo_standard, + solution_index, flags); + }); + }); -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, - int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + return done; } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, - int64_t lda, int64_t stride_a, double *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); -} +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ + b, ldb, strideb, beta, c, ldc, stridec, batch_size, \ + dependencies); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using rocTypeA = typename RocEquivalentType::Type; + using rocTypeB = typename RocEquivalentType::Type; + using rocTypeC = typename RocEquivalentType::Type; + using rocTypeS = typename RocEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], k[i], lda[i], ldb[i], ldc[i], group_size[i]); + } + + int32_t solution_index = 0; + rocblas_gemm_flags flags = rocblas_gemm_flags_none; + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + int64_t offset = 0; + rocblas_status err; + for (int64_t i = 0; i < group_count; i++) { + auto **a_ = reinterpret_cast(a); + auto **b_ = reinterpret_cast(b); + auto **c_ = reinterpret_cast(c); + ROCBLAS_ERROR_FUNC_SYNC( + rocblas_gemm_batched_ex, err, handle, get_rocblas_operation(transa[i]), + get_rocblas_operation(transb[i]), (int)m[i], (int)n[i], (int)k[i], &alpha[i], + a_ + offset, get_rocblas_datatype(), (int)lda[i], b_ + offset, + get_rocblas_datatype(), (int)ldb[i], &beta[i], c_ + offset, + get_rocblas_datatype(), (int)ldc[i], c_ + offset, + get_rocblas_datatype(), (int)ldc[i], (int)group_size[i], + get_rocblas_datatype(), rocblas_gemm_algo_standard, solution_index, + flags); + offset += group_size[i]; + } + }); + }); -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); + return done; } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for column_major layout"); +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ + } + +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) + +#undef GEMM_BATCH_LAUNCHER_USM + +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_BATCH_LAUNCHER_USM + +template +inline sycl::event trsm_batch(Func func, sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, + const T *a, int64_t lda, int64_t stridea, T *b, int64_t ldb, + int64_t strideb, int64_t batch_size, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, stridea, strideb, batch_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), + get_rocblas_fill_mode(upper_lower), + get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag), + m, n, (rocDataType *)&alpha, a_, lda, stridea, b_, ldb, strideb, + batch_size); + }); + }); + + return done; } +#define TRSM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \ + diag unit_diag, int64_t m, int64_t n, TYPE alpha, const TYPE *a, \ + int64_t lda, int64_t stridea, TYPE *b, int64_t ldb, int64_t strideb, \ + int64_t batch_size, const std::vector &dependencies) { \ + return trsm_batch(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, \ + alpha, a, lda, stridea, b, ldb, strideb, batch_size, dependencies); \ + } + +TRSM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_strsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dtrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_ctrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_ztrsm_strided_batched) + +#undef TRSM_STRIDED_BATCH_LAUNCHER_USM + template inline sycl::event trsm_batch(Func func, sycl::queue &queue, side *left_right, uplo *upper_lower, transpose *trans, diag *unit_diag, int64_t *m, int64_t *n, T *alpha, @@ -690,15 +1061,14 @@ inline sycl::event trsm_batch(Func func, sycl::queue &queue, side *left_right, u for (int64_t i = 0; i < group_count; i++) { overflow_check(m[i], n[i], lda[i], ldb[i], group_size[i]); } + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); int64_t offset = 0; rocblas_status err; + for (int64_t i = 0; i < group_count; i++) { auto **a_ = reinterpret_cast(a); auto **b_ = reinterpret_cast(b); @@ -712,6 +1082,7 @@ inline sycl::event trsm_batch(Func func, sycl::queue &queue, side *left_right, u } }); }); + return done; } @@ -732,94 +1103,145 @@ TRSM_BATCH_LAUNCHER_USM(std::complex, rocblas_ztrsm_batched) #undef TRSM_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, float *alpha, const float **a, int64_t *lda, float *beta, - float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} +template +inline sycl::event syrk_batch(Func func, sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, T *alpha, const T **a, int64_t *lda, T *beta, + T **c, int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(n[i], k[i], lda[i], ldc[i], group_size[i]); + } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, double *alpha, const double **a, int64_t *lda, double *beta, - double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + rocblas_status err; -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} + for (int64_t i = 0; i < group_count; i++) { + auto **a_ = reinterpret_cast(a); + auto **c_ = reinterpret_cast(c); + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_fill_mode(upper_lower[i]), + get_rocblas_operation(trans[i]), (int)n[i], (int)k[i], + (rocDataType *)&alpha[i], a_ + offset, (int)lda[i], + (rocDataType *)&beta[i], c_ + offset, (int)ldc[i], + (int)group_size[i]); + offset += group_size[i]; + } + }); + }); -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); + return done; } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t stride_a, float beta, - float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} +#define SYRK_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, \ + int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, TYPE *beta, \ + TYPE **c, int64_t *ldc, int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return syrk_batch(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, beta, \ + c, ldc, group_count, group_size, dependencies); \ + } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t stride_a, double beta, - double *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} +SYRK_BATCH_LAUNCHER_USM(float, rocblas_ssyrk_batched) +SYRK_BATCH_LAUNCHER_USM(double, rocblas_dsyrk_batched) +SYRK_BATCH_LAUNCHER_USM(std::complex, rocblas_csyrk_batched) +SYRK_BATCH_LAUNCHER_USM(std::complex, rocblas_zsyrk_batched) -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} +#undef SYRK_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for column_major layout"); -} +template +inline sycl::event syrk_batch(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, const T alpha, const T *a, int64_t lda, + int64_t stridea, const T beta, T *c, int64_t ldc, int64_t stridec, + int64_t batch_size, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(n, k, lda, ldc, stridea, stridec, batch_size); -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} + auto a_ = reinterpret_cast(a); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_fill_mode(upper_lower), + get_rocblas_operation(trans), n, k, (rocDataType *)&alpha, a_, + lda, stridea, (rocDataType *)&beta, c_, ldc, stridec, + batch_size); + }); + }); -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + return done; } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); +#define SYRK_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, \ + int64_t k, const TYPE alpha, const TYPE *a, int64_t lda, \ + int64_t stridea, const TYPE beta, TYPE *c, int64_t ldc, \ + int64_t stridec, int64_t batch_size, \ + const std::vector &dependencies) { \ + return syrk_batch(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, \ + stridea, beta, c, ldc, stridec, batch_size, dependencies); \ + } + +SYRK_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_ssyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dsyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_csyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zsyrk_strided_batched) + +#undef SYRK_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event omatcopy_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, const T alpha, const T *a, int64_t lda, + int64_t stridea, T *b, int64_t ldb, int64_t strideb, + int64_t batch_size, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, stridea, strideb, batch_size); + + const T beta = 0; + const int64_t new_m = trans == oneapi::mkl::transpose::nontrans ? m : n; + const int64_t new_n = trans == oneapi::mkl::transpose::nontrans ? n : m; + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), new_m, new_n, + (rocDataType *)&alpha, a_, lda, stridea, (rocDataType *)&beta, + nullptr, lda, stridea, b_, ldb, strideb, batch_size); + }); + }); + + return done; } +#define OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, \ + const TYPE alpha, const TYPE *a, int64_t lda, int64_t stridea, \ + TYPE *b, int64_t ldb, int64_t strideb, int64_t batch_size, \ + const std::vector &dependencies) { \ + return omatcopy_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, b, ldb, \ + strideb, batch_size, dependencies); \ + } + +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgeam_strided_batched) + +#undef OMATCOPY_STRIDED_BATCH_LAUNCHER_USM + sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, const std::vector &dependencies) { @@ -846,321 +1268,408 @@ sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64 throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, - float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} +template +inline sycl::event omatadd_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, const T alpha, const T *a, int64_t lda, + int64_t stridea, const T beta, const T *b, int64_t ldb, + int64_t strideb, T *c, int64_t ldc, int64_t stridec, + int64_t batch_size, const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc, stridea, strideb, stridec, batch_size); + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, + lda, stridea, (rocDataType *)&beta, b_, ldb, strideb, c_, ldc, + stridec, batch_size); + }); + }); + + return done; +} + +#define OMATADD_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, const TYPE alpha, const TYPE *a, int64_t lda, \ + int64_t stridea, const TYPE beta, const TYPE *b, int64_t ldb, \ + int64_t strideb, TYPE *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return omatadd_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, stridea, \ + beta, b, ldb, strideb, c, ldc, stridec, batch_size, dependencies); \ + } + +OMATADD_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgeam_strided_batched) + +#undef OMATADD_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event omatcopy_batch(Func func, sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, T *alpha, const T **a, int64_t *lda, T **b, + int64_t *ldb, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + for (int64_t i = 0; i < group_count; i++) { + overflow_check(m[i], n[i], lda[i], ldb[i], group_size[i]); + } + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + int64_t offset = 0; + rocblas_status err; -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, - double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} + for (int64_t i = 0; i < group_count; i++) { + auto **a_ = reinterpret_cast(a); + auto **b_ = reinterpret_cast(b); -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} + const T beta = 0; + const auto new_m = trans[i] == oneapi::mkl::transpose::nontrans ? m[i] : n[i]; + const auto new_n = trans[i] == oneapi::mkl::transpose::nontrans ? n[i] : m[i]; -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for column_major layout"); -} + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans[i]), + get_rocblas_operation(trans[i]), (int)new_m, (int)new_n, + (rocDataType *)&alpha[i], a_ + offset, (int)lda[i], + (rocDataType *)&beta, nullptr, (int)lda[i], b_ + offset, + (int)ldb[i], (int)group_size[i]); + offset += group_size[i]; + } + }); + }); -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - float *alpha, const float **a, int64_t *lda, float **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); + return done; } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - double *alpha, const double **a, int64_t *lda, double **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} +#define OMATCOPY_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, \ + TYPE *alpha, const TYPE **a, int64_t *lda, TYPE **b, int64_t *ldb, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return omatcopy_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb, \ + group_count, group_size, dependencies); \ + } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} +OMATCOPY_BATCH_LAUNCHER_USM(float, rocblas_sgeam_batched) +OMATCOPY_BATCH_LAUNCHER_USM(double, rocblas_dgeam_batched) +OMATCOPY_BATCH_LAUNCHER_USM(std::complex, rocblas_cgeam_batched) +OMATCOPY_BATCH_LAUNCHER_USM(std::complex, rocblas_zgeam_batched) -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for column_major layout"); -} +#undef OMATCOPY_BATCH_LAUNCHER_USM sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, float *alpha, float **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, double *alpha, double **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, + int64_t *ldb, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, + int64_t *ldb, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for column_major layout"); } } // namespace column_major + namespace row_major { // Buffer APIs -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +template +inline void copy_batch(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, + int64_t stridey, int64_t batch_size) { + column_major::copy_batch(func, queue, n, x, incx, stridex, y, incy, stridey, batch_size); } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +#define COPY_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, \ + int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, \ + int64_t batch_size) { \ + copy_batch(ROCBLAS_ROUTINE, queue, n, x, incx, stridex, y, incy, stridey, batch_size); \ + } -void copy_batch(sycl::queue &queue, int64_t n, sycl::buffer, 1> &x, - int64_t incx, int64_t stridex, sycl::buffer, 1> &y, - int64_t incy, int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +COPY_STRIDED_BATCH_LAUNCHER(float, rocblas_scopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER(double, rocblas_dcopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_ccopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zcopy_strided_batched) -void axpy_batch(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, - int64_t stridex, sycl::buffer &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +#undef COPY_STRIDED_BATCH_LAUNCHER -void axpy_batch(sycl::queue &queue, int64_t n, double alpha, sycl::buffer &x, - int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, - int64_t stridey, int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +template +inline void axpy_batch(Func func, sycl::queue &queue, int64_t n, T alpha, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, + int64_t stridey, int64_t batch_size) { + column_major::axpy_batch(func, queue, n, alpha, x, incx, stridex, y, incy, stridey, batch_size); +} + +#define AXPY_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void axpy_batch(sycl::queue &queue, int64_t n, TYPE alpha, sycl::buffer &x, \ + int64_t incx, int64_t stridex, sycl::buffer &y, int64_t incy, \ + int64_t stridey, int64_t batch_size) { \ + axpy_batch(ROCBLAS_ROUTINE, queue, n, alpha, x, incx, stridex, y, incy, stridey, \ + batch_size); \ + } -void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +AXPY_STRIDED_BATCH_LAUNCHER(float, rocblas_saxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER(double, rocblas_daxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_caxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zaxpy_strided_batched) -void axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - sycl::buffer, 1> &x, int64_t incx, int64_t stridex, - sycl::buffer, 1> &y, int64_t incy, int64_t stridey, - int64_t batch_size) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +#undef AXPY_STRIDED_BATCH_LAUNCHER -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, float beta, sycl::buffer &y, int64_t incy, - int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} +template +inline void gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + int64_t stridea, sycl::buffer, 1> &x, int64_t incx, + int64_t stridex, std::complex beta, sycl::buffer, 1> &y, + int64_t incy, int64_t stridey, int64_t batch_size) { + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + alpha = std::conj(alpha); + beta = std::conj(beta); + + if (m > 0) { + queue.submit( + [&](sycl::handler &cgh) { conj_vector(cgh, x, m, incx, stridex, batch_size); }); + + if (n > 0) { + queue.submit( + [&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy, stridey, batch_size); }); + } + } + } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, double beta, - sycl::buffer &y, int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + column_major::gemv_batch(func, queue, new_trans, n, m, alpha, a, lda, stridea, x, incx, stridex, + beta, y, incy, stridey, batch_size); -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, sycl::buffer, 1> &y, - int64_t incy, int64_t stride_y, int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit( + [&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy, stridey, batch_size); }); + } + } } -void gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &x, int64_t incx, - int64_t stride_x, std::complex beta, - sycl::buffer, 1> &y, int64_t incy, int64_t stride_y, - int64_t batch_size) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} +template +inline void gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + T alpha, sycl::buffer &a, int64_t lda, int64_t stridea, + sycl::buffer &x, int64_t incx, int64_t stridex, T beta, + sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::gemv_batch(func, queue, new_trans, n, m, alpha, a, lda, stridea, x, incx, stridex, + beta, y, incy, stridey, batch_size); +} + +#define GEMV_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void gemv_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &x, int64_t incx, int64_t stridex, TYPE beta, \ + sycl::buffer &y, int64_t incy, int64_t stridey, int64_t batch_size) { \ + gemv_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, x, incx, stridex, \ + beta, y, incy, stridey, batch_size); \ + } -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, sycl::buffer &x, - int64_t incx, int64_t stride_x, sycl::buffer &c, int64_t ldc, - int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +GEMV_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemv_strided_batched) -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &x, int64_t incx, int64_t stride_x, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +#undef GEMV_STRIDED_BATCH_LAUNCHER -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +template +inline void dgmm_batch(Func func, sycl::queue &queue, side left_right, int64_t m, int64_t n, + sycl::buffer &a, int64_t lda, int64_t stridea, sycl::buffer &x, + int64_t incx, int64_t stridex, sycl::buffer &c, int64_t ldc, + int64_t stridec, int64_t batch_size) { + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + + column_major::dgmm_batch(func, queue, new_side, n, m, a, lda, stridea, x, incx, stridex, c, ldc, + stridec, batch_size); +} + +#define DGMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &x, int64_t incx, int64_t stridex, \ + sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ + dgmm_batch(ROCBLAS_ROUTINE, queue, left_right, m, n, a, lda, stridea, x, incx, stridex, c, \ + ldc, stridec, batch_size); \ + } -void dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &x, int64_t incx, int64_t stride_x, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +DGMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER(double, rocblas_ddgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zdgmm_strided_batched) -template -inline void gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, - T beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#undef DGMM_STRIDED_BATCH_LAUNCHER + +template +inline void gemm_batch_impl(sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, int64_t k, Ts alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, + Ts beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + auto new_transa = transb; + auto new_transb = transa; + + column_major::gemm_batch(queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, a, lda, + stridea, beta, c, ldc, stridec, batch_size); } -#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ +#undef GEMM_STRIDED_BATCH_LAUNCHER +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE alpha, sycl::buffer &a, int64_t lda, \ - int64_t stride_a, sycl::buffer &b, int64_t ldb, int64_t stride_b, \ - TYPE beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ int64_t batch_size) { \ - gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, \ - ldb, stride_b, beta, c, ldc, stride_c, batch_size); \ + gemm_batch_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, b, ldb, strideb, \ + beta, c, ldc, stridec, batch_size); \ } -GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgemm_strided_batched) + +GEMM_STRIDED_BATCH_LAUNCHER(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +#define GEMM_STRIDED_BATCH_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + void gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, sycl::buffer &b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, sycl::buffer &a, - int64_t lda, int64_t stride_a, sycl::buffer &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER(std::int8_t, std::int8_t, std::int32_t, float) -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +#undef GEMM_STRIDED_BATCH_LAUNCHER -void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, int64_t stride_a, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - int64_t batch_size) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +template +inline void trsm_batch(Func func, sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, + sycl::buffer &a, int64_t lda, int64_t stridea, sycl::buffer &b, + int64_t ldb, int64_t strideb, int64_t batch_size) { + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::trsm_batch(func, queue, new_side, new_uplo, trans, unit_diag, n, m, alpha, a, lda, + stridea, b, ldb, strideb, batch_size); +} + +#define TRSM_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \ + diag unit_diag, int64_t m, int64_t n, TYPE alpha, sycl::buffer &a, \ + int64_t lda, int64_t stridea, sycl::buffer &b, int64_t ldb, \ + int64_t strideb, int64_t batch_size) { \ + trsm_batch(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, \ + a, lda, stridea, b, ldb, strideb, batch_size); \ + } -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, float beta, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +TRSM_STRIDED_BATCH_LAUNCHER(float, rocblas_strsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER(double, rocblas_dtrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_ctrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_ztrsm_strided_batched) -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +#undef TRSM_STRIDED_BATCH_LAUNCHER -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, sycl::buffer, 1> &c, - int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +template +inline void syrk_batch(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, + int64_t k, T alpha, sycl::buffer &a, int64_t lda, int64_t stridea, + T beta, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; -void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + column_major::syrk_batch(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, stridea, beta, + c, ldc, stridec, batch_size); } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +#define SYRK_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, \ + TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, TYPE beta, \ + sycl::buffer &c, int64_t ldc, int64_t stridec, int64_t batch_size) { \ + syrk_batch(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, stridea, beta, \ + c, ldc, stridec, batch_size); \ + } -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, int64_t stride_a, - sycl::buffer &b, int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +SYRK_STRIDED_BATCH_LAUNCHER(float, rocblas_ssyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER(double, rocblas_dsyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_csyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zsyrk_strided_batched) -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, sycl::buffer, 1> &b, int64_t ldb, - int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +#undef SYRK_STRIDED_BATCH_LAUNCHER -void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, sycl::buffer, 1> &b, - int64_t ldb, int64_t stride_b, int64_t batch_size) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +template +inline void omatcopy_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + const T alpha, sycl::buffer &a, int64_t lda, int64_t stridea, + sycl::buffer &b, int64_t ldb, int64_t strideb, + int64_t batch_size) { + return column_major::omatcopy_batch(func, queue, trans, n, m, alpha, a, lda, stridea, b, ldb, + strideb, batch_size); +} + +#define OMATCOPY_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, \ + const TYPE alpha, sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &b, int64_t ldb, int64_t strideb, \ + int64_t batch_size) { \ + omatcopy_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, b, ldb, \ + strideb, batch_size); \ + } + +OMATCOPY_STRIDED_BATCH_LAUNCHER(float, rocblas_sgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER(double, rocblas_dgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgeam_strided_batched) + +#undef OMATCOPY_STRIDED_BATCH_LAUNCHER void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb, int64_t stride, @@ -1186,351 +1695,518 @@ void imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - float beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} - -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, int64_t stride_a, - double beta, sycl::buffer &b, int64_t ldb, int64_t stride_b, - sycl::buffer &c, int64_t ldc, int64_t stride_c, int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} +template +inline void omatadd_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, const T alpha, sycl::buffer &a, int64_t lda, + int64_t stridea, const T beta, sycl::buffer &b, int64_t ldb, + int64_t strideb, sycl::buffer &c, int64_t ldc, int64_t stridec, + int64_t batch_size) { + return column_major::omatadd_batch(func, queue, transa, transb, n, m, alpha, a, lda, stridea, + beta, b, ldb, strideb, c, ldc, stridec, batch_size); +} + +#define OMATADD_STRIDED_BATCH_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, const TYPE alpha, sycl::buffer &a, int64_t lda, \ + int64_t stridea, const TYPE beta, sycl::buffer &b, int64_t ldb, \ + int64_t strideb, sycl::buffer &c, int64_t ldc, int64_t stridec, \ + int64_t batch_size) { \ + omatadd_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, stridea, beta, \ + b, ldb, strideb, c, ldc, stridec, batch_size); \ + } -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} +OMATADD_STRIDED_BATCH_LAUNCHER(float, rocblas_sgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER(double, rocblas_dgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_cgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER(std::complex, rocblas_zgeam_strided_batched) -void omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, - int64_t lda, int64_t stride_a, std::complex beta, - sycl::buffer, 1> &b, int64_t ldb, int64_t stride_b, - sycl::buffer, 1> &c, int64_t ldc, int64_t stride_c, - int64_t batch_size) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} +#undef OMATADD_STRIDED_BATCH_LAUNCHER // USM APIs -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const float **x, int64_t *incx, float **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const double **x, int64_t *incx, double **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); +template +inline sycl::event copy_batch(Func func, sycl::queue &queue, int64_t *n, const T **x, int64_t *incx, + T **y, int64_t *incy, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + return column_major::copy_batch(func, queue, n, x, incx, y, incy, group_count, group_size, + dependencies); } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, int64_t *incx, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +#define COPY_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event copy_batch(sycl::queue &queue, int64_t *n, const TYPE **x, int64_t *incx, \ + TYPE **y, int64_t *incy, int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return copy_batch(ROCBLAS_ROUTINE, queue, n, x, incx, y, incy, group_count, group_size, \ + dependencies); \ + } -sycl::event copy_batch(sycl::queue &queue, int64_t *n, const std::complex **x, - int64_t *incx, std::complex **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +COPY_BATCH_LAUNCHER_USM(float, rocblas_scopy_batched) +COPY_BATCH_LAUNCHER_USM(double, rocblas_dcopy_batched) +COPY_BATCH_LAUNCHER_USM(std::complex, rocblas_ccopy_batched) +COPY_BATCH_LAUNCHER_USM(std::complex, rocblas_zcopy_batched) -sycl::event copy_batch(sycl::queue &queue, int64_t n, const float *x, int64_t incx, int64_t stridex, - float *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +#undef COPY_BATCH_LAUNCHER_USM -sycl::event copy_batch(sycl::queue &queue, int64_t n, const double *x, int64_t incx, - int64_t stridex, double *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +template +inline sycl::event copy_batch(Func func, sycl::queue &queue, int64_t n, const T *x, int64_t incx, + int64_t stridex, T *y, int64_t incy, int64_t stridey, + int64_t batch_size, const std::vector &dependencies) { + return column_major::copy_batch(func, queue, n, x, incx, stridex, y, incy, stridey, batch_size, + dependencies); +} + +#define COPY_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event copy_batch(sycl::queue &queue, int64_t n, const TYPE *x, int64_t incx, \ + int64_t stridex, TYPE *y, int64_t incy, int64_t stridey, \ + int64_t batch_size, const std::vector &dependencies) { \ + return copy_batch(ROCBLAS_ROUTINE, queue, n, x, incx, stridex, y, incy, stridey, \ + batch_size, dependencies); \ + } -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +COPY_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_scopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dcopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_ccopy_strided_batched) +COPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zcopy_strided_batched) -sycl::event copy_batch(sycl::queue &queue, int64_t n, const std::complex *x, int64_t incx, - int64_t stridex, std::complex *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "copy_batch", "for row_major layout"); -} +#undef COPY_STRIDED_BATCH_LAUNCHER_USM -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, float *alpha, const float **x, int64_t *incx, - float **y, int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); +template +inline sycl::event axpy_batch(Func func, sycl::queue &queue, int64_t *n, T *alpha, const T **x, + int64_t *incx, T **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + return column_major::axpy_batch(func, queue, n, alpha, x, incx, y, incy, group_count, + group_size, dependencies); } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, double *alpha, const double **x, - int64_t *incx, double **y, int64_t *incy, int64_t group_count, - int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +#define AXPY_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event axpy_batch(sycl::queue &queue, int64_t *n, TYPE *alpha, const TYPE **x, \ + int64_t *incx, TYPE **y, int64_t *incy, int64_t group_count, \ + int64_t *group_size, const std::vector &dependencies) { \ + return axpy_batch(ROCBLAS_ROUTINE, queue, n, alpha, x, incx, y, incy, group_count, \ + group_size, dependencies); \ + } -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +AXPY_BATCH_LAUNCHER_USM(float, rocblas_saxpy_batched) +AXPY_BATCH_LAUNCHER_USM(double, rocblas_daxpy_batched) +AXPY_BATCH_LAUNCHER_USM(std::complex, rocblas_caxpy_batched) +AXPY_BATCH_LAUNCHER_USM(std::complex, rocblas_zaxpy_batched) -sycl::event axpy_batch(sycl::queue &queue, int64_t *n, std::complex *alpha, - const std::complex **x, int64_t *incx, std::complex **y, - int64_t *incy, int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +#undef AXPY_BATCH_LAUNCHER_USM -sycl::event axpy_batch(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, - int64_t stridex, float *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +template +inline sycl::event axpy_batch(Func func, sycl::queue &queue, int64_t n, T alpha, const T *x, + int64_t incx, int64_t stridex, T *y, int64_t incy, int64_t stridey, + int64_t batch_size, const std::vector &dependencies) { + return column_major::axpy_batch(func, queue, n, alpha, x, incx, stridex, y, incy, stridey, + batch_size, dependencies); +} + +#define AXPY_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event axpy_batch(sycl::queue &queue, int64_t n, TYPE alpha, const TYPE *x, int64_t incx, \ + int64_t stridex, TYPE *y, int64_t incy, int64_t stridey, \ + int64_t batch_size, const std::vector &dependencies) { \ + return axpy_batch(ROCBLAS_ROUTINE, queue, n, alpha, x, incx, stridex, y, incy, stridey, \ + batch_size, dependencies); \ + } -sycl::event axpy_batch(sycl::queue &queue, int64_t n, double alpha, const double *x, int64_t incx, - int64_t stridex, double *y, int64_t incy, int64_t stridey, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +AXPY_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_saxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_daxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_caxpy_strided_batched) +AXPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zaxpy_strided_batched) -sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +#undef AXPY_BATCH_LAUNCHER_USM -sycl::event axpy_batch(sycl::queue &queue, int64_t n, std::complex alpha, - const std::complex *x, int64_t incx, int64_t stridex, - std::complex *y, int64_t incy, int64_t stridey, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "axpy_batch", "for row_major layout"); -} +template +inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + int64_t stridea, const std::complex *x, int64_t incx, + int64_t stridex, std::complex beta, std::complex *y, + int64_t incy, int64_t stridey, int64_t batch_size, + const std::vector &dependencies) { + sycl::event done; -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float beta, float *y, int64_t incy, int64_t stride_y, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, const double *x, - int64_t incx, int64_t stride_x, double beta, double *y, int64_t incy, - int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + if (trans == oneapi::mkl::transpose::conjtrans) { + alpha = std::conj(alpha); + beta = std::conj(beta); -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, - int64_t incy, int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + if (m > 0) { + done = queue.submit([&](sycl::handler &cgh) { + conj_vector(cgh, (std::complex *)x, m, incx, stridex, batch_size); + }); -sycl::event gemv_batch(sycl::queue &queue, transpose transa, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, const std::complex *x, int64_t incx, - int64_t stride_x, std::complex beta, std::complex *y, - int64_t incy, int64_t stride_y, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + if (n > 0) { + done = queue.submit( + [&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy, stridey, batch_size); }); + } + } + } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, float *alpha, - const float **a, int64_t *lda, const float **x, int64_t *incx, float *beta, - float **y, int64_t *incy, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + done.wait_and_throw(); -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, double *alpha, - const double **a, int64_t *lda, const double **x, int64_t *incx, - double *beta, double **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + done = column_major::gemv_batch(func, queue, new_trans, n, m, alpha, a, lda, stridea, x, incx, + stridex, beta, y, incy, stridey, batch_size, dependencies); -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); -} + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, y, n, incy, stridey, batch_size); + }); + } + } -sycl::event gemv_batch(sycl::queue &queue, transpose *transa, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - const std::complex **x, int64_t *incx, std::complex *beta, - std::complex **y, int64_t *incy, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "gemv_batch", "for row_major layout"); + return done; } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const float *a, - int64_t lda, int64_t stride_a, const float *x, int64_t incx, - int64_t stride_x, float *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +template +inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + T alpha, const T *a, int64_t lda, int64_t stridea, const T *x, + int64_t incx, int64_t stridex, T beta, T *y, int64_t incy, + int64_t stridey, int64_t batch_size, + const std::vector &dependencies) { + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, const double *a, - int64_t lda, int64_t stride_a, const double *x, int64_t incx, - int64_t stride_x, double *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + return column_major::gemv_batch(func, queue, new_trans, n, m, alpha, a, lda, stridea, x, incx, + stridex, beta, y, incy, stridey, batch_size, dependencies); } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +#define GEMV_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event gemv_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, int64_t stridea, const TYPE *x, \ + int64_t incx, int64_t stridex, TYPE beta, TYPE *y, int64_t incy, \ + int64_t stridey, int64_t batch_size, \ + const std::vector &dependencies) { \ + return gemv_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, x, incx, \ + stridex, beta, y, incy, stridey, batch_size, dependencies); \ + } -sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, - const std::complex *a, int64_t lda, int64_t stride_a, - const std::complex *x, int64_t incx, int64_t stride_x, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +GEMV_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemv_strided_batched) +GEMV_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemv_strided_batched) -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const float **a, int64_t *lda, const float **x, int64_t *incx, float **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +#undef GEMV_STRIDED_BATCH_LAUNCHER_USM -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const double **a, int64_t *lda, const double **x, int64_t *incx, double **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); -} +template +inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, std::complex *alpha, const std::complex **a, + int64_t *lda, const std::complex **x, int64_t *incx, + std::complex *beta, std::complex **y, int64_t *incy, + int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + sycl::event done; + + int64_t stride = 0; + for (int64_t i = 0; i < group_count; i++) { + if (trans[i] == oneapi::mkl::transpose::conjtrans) { + alpha[i] = std::conj(alpha[i]); + beta[i] = std::conj(beta[i]); + + if (m[i] > 0) { + done = queue.submit([&](sycl::handler &cgh) { + conj_vector(cgh, (std::complex **)x, m[i], incx[i], stride, group_size[i]); + }); + + if (n[i] > 0) { + done = queue.submit([&](sycl::handler &cgh) { + conj_vector(cgh, y, n[i], incy[i], stride, group_size[i]); + }); + } + } + } + stride += group_size[i]; + } + + done.wait_and_throw(); + + auto tmp_trans = std::vector{ (std::size_t)group_count }; + for (int64_t i = 0; i < group_count; i++) { + const auto new_trans = trans[i] == oneapi::mkl::transpose::nontrans + ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + tmp_trans[i] = trans[i]; + trans[i] = new_trans; + } + done = column_major::gemv_batch(func, queue, trans, n, m, alpha, a, lda, x, incx, beta, y, incy, + group_count, group_size, dependencies); + done.wait_and_throw(); + for (int64_t i = 0; i < group_count; i++) { + trans[i] = tmp_trans[i]; + } + + stride = 0; + for (int64_t i = 0; i < group_count; i++) { + if (trans[i] == oneapi::mkl::transpose::conjtrans) { + if (n[i] > 0) { + done = queue.submit([&](sycl::handler &cgh) { + conj_vector(cgh, y, n[i], incy[i], stride, group_size[i]); + }); + } + } + stride += group_size[i]; + } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); + return done; } -sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, - const std::complex **a, int64_t *lda, const std::complex **x, - int64_t *incx, std::complex **c, int64_t *ldc, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "dgmm_batch", "for row_major layout"); +template +inline sycl::event gemv_batch(Func func, sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, T *alpha, const T **a, int64_t *lda, const T **x, + int64_t *incx, T *beta, T **y, int64_t *incy, int64_t group_count, + int64_t *group_size, const std::vector &dependencies) { + auto tmp_trans = std::vector{ static_cast(group_count) }; + + for (int64_t i = 0; i < group_count; i++) { + const auto new_trans = trans[i] == oneapi::mkl::transpose::nontrans + ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + tmp_trans[i] = trans[i]; + trans[i] = new_trans; + } + auto done = column_major::gemv_batch(func, queue, trans, n, m, alpha, a, lda, x, incx, beta, y, + incy, group_count, group_size, dependencies); + done.wait_and_throw(); + for (int64_t i = 0; i < group_count; i++) { + trans[i] = tmp_trans[i]; + } + + return done; } +#define GEMV_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event gemv_batch( \ + sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, TYPE *alpha, const TYPE **a, \ + int64_t *lda, const TYPE **x, int64_t *incx, TYPE *beta, TYPE **y, int64_t *incy, \ + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { \ + return gemv_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, x, incx, beta, y, \ + incy, group_count, group_size, dependencies); \ + } + +GEMV_BATCH_LAUNCHER_USM(float, rocblas_sgemv_batched) +GEMV_BATCH_LAUNCHER_USM(double, rocblas_dgemv_batched) +GEMV_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemv_batched) +GEMV_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemv_batched) + +#undef GEMV_BATCH_LAUNCHER_USM + +template +inline sycl::event dgmm_batch(Func func, sycl::queue &queue, side left_right, int64_t m, int64_t n, + const T *a, int64_t lda, int64_t stridea, const T *x, int64_t incx, + int64_t stridex, T *c, int64_t ldc, int64_t stridec, + int64_t batch_size, const std::vector &dependencies) { + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + + return column_major::dgmm_batch(func, queue, new_side, n, m, a, lda, stridea, x, incx, stridex, + c, ldc, stridec, batch_size, dependencies); +} + +#define DGMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event dgmm_batch(sycl::queue &queue, side left_right, int64_t m, int64_t n, \ + const TYPE *a, int64_t lda, int64_t stridea, const TYPE *x, \ + int64_t incx, int64_t stridex, TYPE *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return dgmm_batch(ROCBLAS_ROUTINE, queue, left_right, m, n, a, lda, stridea, x, incx, \ + stridex, c, ldc, stridec, batch_size, dependencies); \ + } + +DGMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_ddgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cdgmm_strided_batched) +DGMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_strided_batched) + +#undef DGMM_STRIDED_BATCH_LAUNCHER_USM + template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, - int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, - int64_t stride_a, const T *b, int64_t ldb, int64_t stride_b, T beta, - T *c, int64_t ldc, int64_t stride_c, int64_t batch_size, +inline sycl::event dgmm_batch(Func func, sycl::queue &queue, side *left_right, int64_t *m, + int64_t *n, const T **a, int64_t *lda, const T **x, int64_t *incx, + T **c, int64_t *ldc, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); + for (int64_t i = 0; i < group_count; i++) { + const auto new_side = left_right[i] == oneapi::mkl::side::left ? oneapi::mkl::side::right + : oneapi::mkl::side::left; + left_right[i] = new_side; + } + + return column_major::dgmm_batch(func, queue, left_right, n, m, a, lda, x, incx, c, ldc, + group_count, group_size, dependencies); } -#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ - int64_t n, int64_t k, TYPE alpha, const TYPE *a, int64_t lda, \ - int64_t stride_a, const TYPE *b, int64_t ldb, int64_t stride_b, \ - TYPE beta, TYPE *c, int64_t ldc, int64_t stride_c, int64_t batch_size, \ +#define DGMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event dgmm_batch(sycl::queue &queue, side *left_right, int64_t *m, int64_t *n, \ + const TYPE **a, int64_t *lda, const TYPE **x, int64_t *incx, TYPE **c, \ + int64_t *ldc, int64_t group_count, int64_t *group_size, \ const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, \ - stride_a, b, ldb, stride_b, beta, c, ldc, stride_c, batch_size, \ - dependencies); \ + return dgmm_batch(ROCBLAS_ROUTINE, queue, left_right, m, n, a, lda, x, incx, c, ldc, \ + group_count, group_size, dependencies); \ + } + +DGMM_BATCH_LAUNCHER_USM(float, rocblas_sdgmm_batched) +DGMM_BATCH_LAUNCHER_USM(double, rocblas_ddgmm_batched) +DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cdgmm_batched) +DGMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zdgmm_batched) + +#undef DGMM_BATCH_LAUNCHER + +template +inline sycl::event gemm_batch_strided_usm_impl(sycl::queue &queue, transpose transa, + transpose transb, int64_t m, int64_t n, int64_t k, + Ts alpha, const Ta *a, int64_t lda, int64_t stridea, + const Tb *b, int64_t ldb, int64_t strideb, Ts beta, + Tc *c, int64_t ldc, int64_t stridec, + int64_t batch_size, + const std::vector &dependencies) { + auto new_transa = transb; + auto new_transb = transa; + + return column_major::gemm_batch(queue, new_transa, new_transb, n, m, k, alpha, b, ldb, strideb, + a, lda, stridea, beta, c, ldc, stridec, batch_size, + dependencies); +} + +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return gemm_batch_strided_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, stridea, \ + b, ldb, strideb, beta, c, ldc, stridec, batch_size, \ + dependencies); \ } -GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_strided_batched) -GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_strided_batched) +GEMM_STRIDED_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_STRIDED_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) #undef GEMM_STRIDED_BATCH_LAUNCHER_USM -template -inline sycl::event gemm_batch(Func func, sycl::queue &queue, transpose *transa, transpose *transb, - int64_t *m, int64_t *n, int64_t *k, T *alpha, const T **a, - int64_t *lda, const T **b, int64_t *ldb, T *beta, T **c, int64_t *ldc, - int64_t group_count, int64_t *group_size, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm_batch", "for row_major layout"); +#define GEMM_STRIDED_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, \ + int64_t stridea, const TYPE_B *b, int64_t ldb, int64_t strideb, \ + TYPE_S beta, TYPE_C *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } + +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_STRIDED_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) + +#undef GEMM_STRIDED_BATCH_LAUNCHER_USM + +template +inline sycl::event gemm_batch_usm_impl(sycl::queue &queue, transpose *transa, transpose *transb, + int64_t *m, int64_t *n, int64_t *k, Ts *alpha, const Ta **a, + int64_t *lda, const Tb **b, int64_t *ldb, Ts *beta, Tc **c, + int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + for (int64_t i = 0; i < group_count; i++) { + std::swap(transa[i], transb[i]); + } + + return column_major::gemm_batch(queue, transa, transb, n, m, k, alpha, b, ldb, a, lda, beta, c, + ldc, group_count, group_size, dependencies); } -#define GEMM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ - sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ - int64_t *n, int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, \ - const TYPE **b, int64_t *ldb, TYPE *beta, TYPE **c, int64_t *ldc, \ - int64_t group_count, int64_t *group_size, \ - const std::vector &dependencies) { \ - return gemm_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ - beta, c, ldc, group_count, group_size, dependencies); \ +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return gemm_batch_usm_impl(queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, \ + ldc, group_count, group_size, dependencies); \ } -GEMM_BATCH_LAUNCHER_USM(sycl::half, rocblas_hgemm_batched) -GEMM_BATCH_LAUNCHER_USM(float, rocblas_sgemm_batched) -GEMM_BATCH_LAUNCHER_USM(double, rocblas_dgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_cgemm_batched) -GEMM_BATCH_LAUNCHER_USM(std::complex, rocblas_zgemm_batched) +GEMM_BATCH_LAUNCHER_USM(float, float, float, float) +GEMM_BATCH_LAUNCHER_USM(double, double, double, double) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(std::complex, std::complex, std::complex, + std::complex) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half) +GEMM_BATCH_LAUNCHER_USM(sycl::half, sycl::half, float, float) #undef GEMM_BATCH_LAUNCHER_USM -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, float alpha, const float *a, - int64_t lda, int64_t stride_a, float *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +#define GEMM_BATCH_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S) \ + sycl::event gemm_batch(sycl::queue &queue, transpose *transa, transpose *transb, int64_t *m, \ + int64_t *n, int64_t *k, TYPE_S *alpha, const TYPE_A **a, int64_t *lda, \ + const TYPE_B **b, int64_t *ldb, TYPE_S *beta, TYPE_C **c, int64_t *ldc, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + throw unimplemented("blas", "gemm_batch", \ + std::string("for dtype unimplemented dtype combination <") + \ + dtype_string() + "," + dtype_string() + "," + \ + dtype_string() + "," + dtype_string() + ">"); \ + } -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, double alpha, const double *a, - int64_t lda, int64_t stride_a, double *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, float, float) +GEMM_BATCH_LAUNCHER_USM(std::int8_t, std::int8_t, std::int32_t, float) -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); -} +#undef GEMM_BATCH_LAUNCHER_USM -sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, - diag unit_diag, int64_t m, int64_t n, std::complex alpha, - const std::complex *a, int64_t lda, int64_t stride_a, - std::complex *b, int64_t ldb, int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); +template +inline sycl::event trsm_batch(Func func, sycl::queue &queue, side left_right, uplo upper_lower, + transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, + const T *a, int64_t lda, int64_t stridea, T *b, int64_t ldb, + int64_t strideb, int64_t batch_size, + const std::vector &dependencies) { + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::trsm_batch(func, queue, new_side, new_uplo, trans, unit_diag, n, m, alpha, + a, lda, stridea, b, ldb, strideb, batch_size, dependencies); } +#define TRSM_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event trsm_batch(sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, \ + diag unit_diag, int64_t m, int64_t n, TYPE alpha, const TYPE *a, \ + int64_t lda, int64_t stridea, TYPE *b, int64_t ldb, int64_t strideb, \ + int64_t batch_size, const std::vector &dependencies) { \ + return trsm_batch(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, \ + alpha, a, lda, stridea, b, ldb, strideb, batch_size, dependencies); \ + } + +TRSM_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_strsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dtrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_ctrsm_strided_batched) +TRSM_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_ztrsm_strided_batched) + +#undef TRSM_STRIDED_BATCH_LAUNCHER_USM + template inline sycl::event trsm_batch(Func func, sycl::queue &queue, side *left_right, uplo *upper_lower, transpose *trans, diag *unit_diag, int64_t *m, int64_t *n, T *alpha, const T **a, int64_t *lda, T **b, int64_t *ldb, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { - throw unimplemented("blas", "trsm_batch", "for row_major layout"); + for (int64_t i = 0; i < group_count; i++) { + const auto new_side = left_right[i] == oneapi::mkl::side::left ? oneapi::mkl::side::right + : oneapi::mkl::side::left; + left_right[i] = new_side; + + const auto new_uplo = upper_lower[i] == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + upper_lower[i] = new_uplo; + } + + return column_major::trsm_batch(func, queue, left_right, upper_lower, trans, unit_diag, n, m, + alpha, a, lda, b, ldb, group_count, group_size, dependencies); } #define TRSM_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1550,93 +2226,98 @@ TRSM_BATCH_LAUNCHER_USM(std::complex, rocblas_ztrsm_batched) #undef TRSM_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, float *alpha, const float **a, int64_t *lda, float *beta, - float **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} - -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, double *alpha, const double **a, int64_t *lda, double *beta, - double **c, int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +template +inline sycl::event syrk_batch(Func func, sycl::queue &queue, uplo *upper_lower, transpose *trans, + int64_t *n, int64_t *k, T *alpha, const T **a, int64_t *lda, T *beta, + T **c, int64_t *ldc, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + for (int64_t i = 0; i < group_count; i++) { + const auto new_uplo = upper_lower[i] == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + upper_lower[i] = new_uplo; + + const auto new_trans = trans[i] == oneapi::mkl::transpose::nontrans + ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + trans[i] = new_trans; + } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); + return column_major::syrk_batch(func, queue, upper_lower, trans, n, k, alpha, a, lda, beta, c, + ldc, group_count, group_size, dependencies); } -sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, - int64_t *k, std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex *beta, std::complex **c, - int64_t *ldc, int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +#define SYRK_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event syrk_batch(sycl::queue &queue, uplo *upper_lower, transpose *trans, int64_t *n, \ + int64_t *k, TYPE *alpha, const TYPE **a, int64_t *lda, TYPE *beta, \ + TYPE **c, int64_t *ldc, int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return syrk_batch(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, beta, \ + c, ldc, group_count, group_size, dependencies); \ + } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - float alpha, const float *a, int64_t lda, int64_t stride_a, float beta, - float *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +SYRK_BATCH_LAUNCHER_USM(float, rocblas_ssyrk_batched) +SYRK_BATCH_LAUNCHER_USM(double, rocblas_dsyrk_batched) +SYRK_BATCH_LAUNCHER_USM(std::complex, rocblas_csyrk_batched) +SYRK_BATCH_LAUNCHER_USM(std::complex, rocblas_zsyrk_batched) -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - double alpha, const double *a, int64_t lda, int64_t stride_a, double beta, - double *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +#undef SYRK_BATCH_LAUNCHER_USM -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +template +inline sycl::event syrk_batch(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + int64_t n, int64_t k, const T alpha, const T *a, int64_t lda, + int64_t stridea, const T beta, T *c, int64_t ldc, int64_t stridec, + int64_t batch_size, const std::vector &dependencies) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::syrk_batch(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, stridea, + beta, c, ldc, stridec, batch_size, dependencies); +} + +#define SYRK_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, \ + int64_t k, const TYPE alpha, const TYPE *a, int64_t lda, \ + int64_t stridea, const TYPE beta, TYPE *c, int64_t ldc, \ + int64_t stridec, int64_t batch_size, \ + const std::vector &dependencies) { \ + return syrk_batch(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, \ + stridea, beta, c, ldc, stridec, batch_size, dependencies); \ + } -sycl::event syrk_batch(sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex beta, std::complex *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "syrk_batch", "for row_major layout"); -} +SYRK_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_ssyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dsyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_csyrk_strided_batched) +SYRK_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zsyrk_strided_batched) -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, int64_t stride_a, float *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +#undef SYRK_STRIDED_BATCH_LAUNCHER_USM -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, int64_t stride_a, double *b, int64_t ldb, - int64_t stride_b, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +template +inline sycl::event omatcopy_batch(Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, const T alpha, const T *a, int64_t lda, + int64_t stridea, T *b, int64_t ldb, int64_t strideb, + int64_t batch_size, + const std::vector &dependencies) { + return column_major::omatcopy_batch(func, queue, trans, n, m, alpha, a, lda, stridea, b, ldb, + strideb, batch_size, dependencies); +} + +#define OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, \ + const TYPE alpha, const TYPE *a, int64_t lda, int64_t stridea, \ + TYPE *b, int64_t ldb, int64_t strideb, int64_t batch_size, \ + const std::vector &dependencies) { \ + return omatcopy_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, stridea, b, ldb, \ + strideb, batch_size, dependencies); \ + } -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgeam_strided_batched) +OMATCOPY_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgeam_strided_batched) -sycl::event omatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - int64_t stride_a, std::complex *b, int64_t ldb, int64_t stride_b, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +#undef OMATCOPY_STRIDED_BATCH_LAUNCHER_USM sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, int64_t stride, int64_t batch_size, @@ -1664,93 +2345,83 @@ sycl::event imatcopy_batch(sycl::queue &queue, transpose trans, int64_t m, int64 throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, float alpha, const float *a, int64_t lda, int64_t stride_a, - float beta, const float *b, int64_t ldb, int64_t stride_b, float *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} - -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, double alpha, const double *a, int64_t lda, int64_t stride_a, - double beta, const double *b, int64_t ldb, int64_t stride_b, double *c, - int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} - -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, int64_t batch_size, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} +template +inline sycl::event omatadd_batch(Func func, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, const T alpha, const T *a, int64_t lda, + int64_t stridea, const T beta, const T *b, int64_t ldb, + int64_t strideb, T *c, int64_t ldc, int64_t stridec, + int64_t batch_size, const std::vector &dependencies) { + return column_major::omatadd_batch(func, queue, transa, transb, n, m, alpha, a, lda, stridea, + beta, b, ldb, strideb, c, ldc, stridec, batch_size, + dependencies); +} + +#define OMATADD_STRIDED_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, const TYPE alpha, const TYPE *a, int64_t lda, \ + int64_t stridea, const TYPE beta, const TYPE *b, int64_t ldb, \ + int64_t strideb, TYPE *c, int64_t ldc, int64_t stridec, \ + int64_t batch_size, const std::vector &dependencies) { \ + return omatadd_batch(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, stridea, \ + beta, b, ldb, strideb, c, ldc, stridec, batch_size, dependencies); \ + } -sycl::event omatadd_batch(sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, std::complex alpha, const std::complex *a, - int64_t lda, int64_t stride_a, std::complex beta, - const std::complex *b, int64_t ldb, int64_t stride_b, - std::complex *c, int64_t ldc, int64_t stride_c, - int64_t batch_size, const std::vector &dependencies) { - throw unimplemented("blas", "omatadd_batch", "for row_major layout"); -} +OMATADD_STRIDED_BATCH_LAUNCHER_USM(float, rocblas_sgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER_USM(double, rocblas_dgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_cgeam_strided_batched) +OMATADD_STRIDED_BATCH_LAUNCHER_USM(std::complex, rocblas_zgeam_strided_batched) -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - float *alpha, const float **a, int64_t *lda, float **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +#undef OMATADD_STRIDED_BATCH_LAUNCHER_USM -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - double *alpha, const double **a, int64_t *lda, double **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +template +inline sycl::event omatcopy_batch(Func func, sycl::queue &queue, transpose *trans, int64_t *m, + int64_t *n, T *alpha, const T **a, int64_t *lda, T **b, + int64_t *ldb, int64_t group_count, int64_t *group_size, + const std::vector &dependencies) { + return column_major::omatcopy_batch(func, queue, trans, n, m, alpha, a, lda, b, ldb, + group_count, group_size, dependencies); +} + +#define OMATCOPY_BATCH_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, \ + TYPE *alpha, const TYPE **a, int64_t *lda, TYPE **b, int64_t *ldb, \ + int64_t group_count, int64_t *group_size, \ + const std::vector &dependencies) { \ + return omatcopy_batch(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb, \ + group_count, group_size, dependencies); \ + } -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, int64_t *lda, - std::complex **b, int64_t *ldb, int64_t group_count, - int64_t *groupsize, const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +OMATCOPY_BATCH_LAUNCHER_USM(float, rocblas_sgeam_batched) +OMATCOPY_BATCH_LAUNCHER_USM(double, rocblas_dgeam_batched) +OMATCOPY_BATCH_LAUNCHER_USM(std::complex, rocblas_cgeam_batched) +OMATCOPY_BATCH_LAUNCHER_USM(std::complex, rocblas_zgeam_batched) -sycl::event omatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, - std::complex *alpha, const std::complex **a, - int64_t *lda, std::complex **b, int64_t *ldb, - int64_t group_count, int64_t *groupsize, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy_batch", "for row_major layout"); -} +#undef OMATCOPY_BATCH_LAUNCHER_USM sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, float *alpha, float **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, double *alpha, double **ab, int64_t *lda, int64_t *ldb, - int64_t group_count, int64_t *groupsize, + int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, + int64_t *ldb, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } sycl::event imatcopy_batch(sycl::queue &queue, transpose *trans, int64_t *m, int64_t *n, std::complex *alpha, std::complex **ab, int64_t *lda, - int64_t *ldb, int64_t group_count, int64_t *groupsize, + int64_t *ldb, int64_t group_count, int64_t *group_size, const std::vector &dependencies) { throw unimplemented("blas", "imatcopy_batch", "for row_major layout"); } diff --git a/src/blas/backends/rocblas/rocblas_extensions.cpp b/src/blas/backends/rocblas/rocblas_extensions.cpp index 315f9ce30..a1fd1df1c 100644 --- a/src/blas/backends/rocblas/rocblas_extensions.cpp +++ b/src/blas/backends/rocblas/rocblas_extensions.cpp @@ -18,6 +18,7 @@ * limitations under the License. * **************************************************************************/ + #include "rocblas_helper.hpp" #include "rocblas_task.hpp" @@ -88,27 +89,68 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for column_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} - -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +template +inline void omatcopy(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + const T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); + + const T beta = 0; + const int64_t new_m = trans == oneapi::mkl::transpose::nontrans ? m : n; + const int64_t new_n = trans == oneapi::mkl::transpose::nontrans ? n : m; + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), new_m, new_n, + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); + }); + }); +} + +#define OMATCOPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, const TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } + +OMATCOPY_LAUNCHER(float, rocblas_sgeam) +OMATCOPY_LAUNCHER(double, rocblas_dgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER + +template +void omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); +} + +#define OMATCOPY2_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &b, int64_t ldb, int64_t strideb) { \ + omatcopy2(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, lda, \ + b, ldb, strideb); \ + } + +OMATCOPY2_LAUNCHER(float, "unimplemented") +OMATCOPY2_LAUNCHER(double, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") + +#undef OMATCOPY2_LAUNCHER void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { @@ -130,31 +172,45 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for column_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +template +inline void omatadd(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, const T alpha, sycl::buffer &a, int64_t lda, const T beta, + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + + queue.submit([&](sycl::handler &cgh) { + auto a_acc = a.template get_access(cgh); + auto b_acc = b.template get_access(cgh); + auto c_acc = c.template get_access(cgh); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = sc.get_mem(a_acc); + auto b_ = sc.get_mem(b_acc); + auto c_ = sc.get_mem(c_acc); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, + lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); +} + +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + const TYPE alpha, sycl::buffer &a, int64_t lda, const TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, \ + ldc); \ + } + +OMATADD_LAUNCHER(float, rocblas_sgeam) +OMATADD_LAUNCHER(double, rocblas_dgeam) +OMATADD_LAUNCHER(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATADD_LAUNCHER // USM APIs @@ -220,32 +276,71 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for column_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} +template +inline sycl::event omatcopy(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + const T alpha, const T *a, int64_t lda, T *b, int64_t ldb, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb); -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); -} + const T beta = 0; + const int64_t new_m = trans == oneapi::mkl::transpose::nontrans ? m : n; + const int64_t new_n = trans == oneapi::mkl::transpose::nontrans ? n : m; -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(trans), + get_rocblas_operation(trans), new_m, new_n, + (rocDataType *)&alpha, a_, lda, (rocDataType *)&beta, nullptr, + lda, b_, ldb); + }); + }); + + return done; } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for column_major layout"); +#define OMATCOPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, \ + const TYPE alpha, const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb, dependencies); \ + } + +OMATCOPY_LAUNCHER_USM(float, rocblas_sgeam) +OMATCOPY_LAUNCHER_USM(double, rocblas_dgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER_USM + +template +sycl::event omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, int64_t stridea, T *b, + int64_t ldb, int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); } +#define OMATCOPY2_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, int64_t stridea, TYPE *b, int64_t ldb, \ + int64_t strideb, const std::vector &dependencies) { \ + return omatcopy2(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, \ + lda, b, ldb, strideb, dependencies); \ + } + +OMATCOPY2_LAUNCHER_USM(float, "unimplemented") +OMATCOPY2_LAUNCHER_USM(double, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") + +#undef OMATCOPY2_LAUNCHER_USM + sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, const std::vector &dependencies) { @@ -270,37 +365,50 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for column_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} - -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for column_major layout"); -} +template +inline sycl::event omatadd(Func func, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, const T alpha, const T *a, int64_t lda, + const T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + using rocDataType = typename RocEquivalentType::Type; + overflow_check(m, n, lda, ldb, ldc); + + auto done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { + auto handle = sc.get_handle(queue); + + auto a_ = reinterpret_cast(a); + auto b_ = reinterpret_cast(b); + auto c_ = reinterpret_cast(c); + rocblas_status err; + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), + get_rocblas_operation(transb), m, n, (rocDataType *)&alpha, a_, + lda, (rocDataType *)&beta, b_, ldb, c_, ldc); + }); + }); + + return done; +} + +#define OMATADD_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, const TYPE alpha, const TYPE *a, int64_t lda, const TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, b, ldb, \ + c, ldc, dependencies); \ + } + +OMATADD_LAUNCHER_USM(float, rocblas_sgeam) +OMATADD_LAUNCHER_USM(double, rocblas_dgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_zgeam) + +#undef OMATADD_LAUNCHER_USM } // namespace column_major + namespace row_major { // Buffer APIs @@ -361,28 +469,48 @@ void gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transpose tra throw unimplemented("blas", "gemmt", "for row_major layout"); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); +template +inline void omatcopy(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + const T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, + int64_t ldb) { + column_major::omatcopy(func, queue, trans, n, m, alpha, a, lda, b, ldb); } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} +#define OMATCOPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, const TYPE alpha, \ + sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { \ + omatcopy(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb); \ + } -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} +OMATCOPY_LAUNCHER(float, rocblas_sgeam) +OMATCOPY_LAUNCHER(double, rocblas_dgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER(std::complex, rocblas_zgeam) + +#undef OMATCOPY_LAUNCHER -void omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::complex alpha, - sycl::buffer, 1> &a, int64_t lda, - sycl::buffer, 1> &b, int64_t ldb) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); +template +void omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, int64_t m, + int64_t n, T alpha, sycl::buffer &a, int64_t lda, std::int64_t stridea, + sycl::buffer &b, int64_t ldb, std::int64_t strideb) { + throw unimplemented("blas", "omatcopy2", ""); } +#define OMATCOPY2_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + sycl::buffer &a, int64_t lda, int64_t stridea, \ + sycl::buffer &b, int64_t ldb, int64_t strideb) { \ + omatcopy2(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, lda, \ + b, ldb, strideb); \ + } + +OMATCOPY2_LAUNCHER(float, "unimplemented") +OMATCOPY2_LAUNCHER(double, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER(std::complex, "unimplemented") + +#undef OMATCOPY2_LAUNCHER + void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, sycl::buffer &ab, int64_t lda, int64_t ldb) { throw unimplemented("blas", "imatcopy", "for row_major layout"); @@ -403,31 +531,27 @@ void imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, std::co throw unimplemented("blas", "imatcopy", "for row_major layout"); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, sycl::buffer &a, int64_t lda, float beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); +template +inline void omatadd(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, + int64_t n, const T alpha, sycl::buffer &a, int64_t lda, const T beta, + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { + column_major::omatadd(func, queue, transa, transb, n, m, alpha, a, lda, beta, b, ldb, c, ldc); } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, sycl::buffer &a, int64_t lda, double beta, - sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#define OMATADD_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ + void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ + const TYPE alpha, sycl::buffer &a, int64_t lda, const TYPE beta, \ + sycl::buffer &b, int64_t ldb, sycl::buffer &c, int64_t ldc) { \ + omatadd(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, b, ldb, c, \ + ldc); \ + } -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +OMATADD_LAUNCHER(float, rocblas_sgeam) +OMATADD_LAUNCHER(double, rocblas_dgeam) +OMATADD_LAUNCHER(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER(std::complex, rocblas_zgeam) -void omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, sycl::buffer, 1> &a, int64_t lda, - std::complex beta, sycl::buffer, 1> &b, int64_t ldb, - sycl::buffer, 1> &c, int64_t ldc) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#undef OMATADD_LAUNCHER // USM APIs @@ -493,32 +617,49 @@ sycl::event gemmt(sycl::queue &queue, uplo upper_lower, transpose transa, transp throw unimplemented("blas", "gemmt", "for row_major layout"); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, - const float *a, int64_t lda, float *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); +template +inline sycl::event omatcopy(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + const T alpha, const T *a, int64_t lda, T *b, int64_t ldb, + const std::vector &dependencies) { + return column_major::omatcopy(func, queue, trans, n, m, alpha, a, lda, b, ldb, dependencies); } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, double alpha, - const double *a, int64_t lda, double *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} +#define OMATCOPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, \ + const TYPE alpha, const TYPE *a, int64_t lda, TYPE *b, int64_t ldb, \ + const std::vector &dependencies) { \ + return omatcopy(ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, lda, b, ldb, dependencies); \ + } -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); -} +OMATCOPY_LAUNCHER_USM(float, rocblas_sgeam) +OMATCOPY_LAUNCHER_USM(double, rocblas_dgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATCOPY_LAUNCHER_USM(std::complex, rocblas_zgeam) -sycl::event omatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex *b, int64_t ldb, - const std::vector &dependencies) { - throw unimplemented("blas", "omatcopy", "for row_major layout"); +#undef OMATCOPY_LAUNCHER_USM + +template +sycl::event omatcopy2(const char *func_name, Func func, sycl::queue &queue, transpose trans, + int64_t m, int64_t n, T alpha, const T *a, int64_t lda, int64_t stridea, T *b, + int64_t ldb, int64_t strideb, const std::vector &dependencies) { + throw unimplemented("blas", "omatcopy2", ""); } +#define OMATCOPY2_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatcopy2(sycl::queue &queue, transpose trans, int64_t m, int64_t n, TYPE alpha, \ + const TYPE *a, int64_t lda, int64_t stridea, TYPE *b, int64_t ldb, \ + int64_t strideb, const std::vector &dependencies) { \ + return omatcopy2(#ROCBLAS_ROUTINE, ROCBLAS_ROUTINE, queue, trans, m, n, alpha, a, stridea, \ + lda, b, ldb, strideb, dependencies); \ + } + +OMATCOPY2_LAUNCHER_USM(float, "unimplemented") +OMATCOPY2_LAUNCHER_USM(double, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") +OMATCOPY2_LAUNCHER_USM(std::complex, "unimplemented") + +#undef OMATCOPY2_LAUNCHER_USM + sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, float alpha, float *ab, int64_t lda, int64_t ldb, const std::vector &dependencies) { @@ -543,35 +684,30 @@ sycl::event imatcopy(sycl::queue &queue, transpose trans, int64_t m, int64_t n, throw unimplemented("blas", "imatcopy", "for row_major layout"); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - float alpha, const float *a, int64_t lda, float beta, const float *b, - int64_t ldb, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); +template +inline sycl::event omatadd(Func func, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, const T alpha, const T *a, int64_t lda, + const T beta, const T *b, int64_t ldb, T *c, int64_t ldc, + const std::vector &dependencies) { + return column_major::omatadd(func, queue, transa, transb, n, m, alpha, a, lda, beta, b, ldb, c, + ldc, dependencies); } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - double alpha, const double *a, int64_t lda, double beta, const double *b, - int64_t ldb, double *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#define OMATADD_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ + sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, \ + int64_t n, const TYPE alpha, const TYPE *a, int64_t lda, const TYPE beta, \ + const TYPE *b, int64_t ldb, TYPE *c, int64_t ldc, \ + const std::vector &dependencies) { \ + return omatadd(ROCBLAS_ROUTINE, queue, transa, transb, m, n, alpha, a, lda, beta, b, ldb, \ + c, ldc, dependencies); \ + } -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +OMATADD_LAUNCHER_USM(float, rocblas_sgeam) +OMATADD_LAUNCHER_USM(double, rocblas_dgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_cgeam) +OMATADD_LAUNCHER_USM(std::complex, rocblas_zgeam) -sycl::event omatadd(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - std::complex alpha, const std::complex *a, int64_t lda, - std::complex beta, const std::complex *b, int64_t ldb, - std::complex *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "omatadd", "for row_major layout"); -} +#undef OMATADD_LAUNCHER_USM } // namespace row_major } // namespace rocblas diff --git a/src/blas/backends/rocblas/rocblas_helper.hpp b/src/blas/backends/rocblas/rocblas_helper.hpp index 75490e333..ae6301a7a 100644 --- a/src/blas/backends/rocblas/rocblas_helper.hpp +++ b/src/blas/backends/rocblas/rocblas_helper.hpp @@ -27,10 +27,11 @@ #define _ROCBLAS_HELPER_HPP_ -#include +#include #include #include "oneapi/mkl/types.hpp" #include +#include "dtype_string.hpp" namespace oneapi { namespace mkl { @@ -205,6 +206,66 @@ inline rocblas_side get_rocblas_side_mode(oneapi::mkl::side lr) { } } +template +inline rocblas_datatype get_rocblas_datatype() { + static_assert(false); +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f16_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f64_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f32_c; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_f64_c; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_i8_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_u8_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_i32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_u32_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype() { + return rocblas_datatype_bf16_r; +} + +template <> +inline rocblas_datatype get_rocblas_datatype>() { + return rocblas_datatype_bf16_c; +} + /*converting std::complex to roc__complex sycl::half to rocblas_half*/ template diff --git a/src/blas/backends/rocblas/rocblas_level1.cpp b/src/blas/backends/rocblas/rocblas_level1.cpp index 1d89d7b83..3a1eacb38 100644 --- a/src/blas/backends/rocblas/rocblas_level1.cpp +++ b/src/blas/backends/rocblas/rocblas_level1.cpp @@ -18,6 +18,7 @@ * limitations under the License. * **************************************************************************/ + #include "rocblas_helper.hpp" #include "rocblas_task.hpp" @@ -32,7 +33,6 @@ namespace column_major { // Buffer APIs -// Level 1 template inline void asum(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &result) { @@ -69,10 +69,12 @@ inline void asum(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ asum(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + ASUM_LAUNCHER(float, float, rocblas_sasum) ASUM_LAUNCHER(double, double, rocblas_dasum) ASUM_LAUNCHER(std::complex, float, rocblas_scasum) ASUM_LAUNCHER(std::complex, double, rocblas_dzasum) + #undef ASUM_LAUNCHER template @@ -81,6 +83,7 @@ inline void scal(Func func, sycl::queue &queue, int64_t n, T1 a, sycl::buffer::Type; using rocDataType2 = typename RocEquivalentType::Type; overflow_check(n, incx); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { @@ -97,12 +100,14 @@ inline void scal(Func func, sycl::queue &queue, int64_t n, T1 a, sycl::buffer &x, int64_t incx) { \ scal(ROCBLAS_ROUTINE, queue, n, a, x, incx); \ } + SCAL_LAUNCHER(float, float, rocblas_sscal) SCAL_LAUNCHER(double, double, rocblas_dscal) SCAL_LAUNCHER(std::complex, std::complex, rocblas_cscal) SCAL_LAUNCHER(std::complex, std::complex, rocblas_zscal) SCAL_LAUNCHER(float, std::complex, rocblas_csscal) SCAL_LAUNCHER(double, std::complex, rocblas_zdscal) + #undef SCAL_LAUNCHER template @@ -110,6 +115,7 @@ inline void axpy(Func func, sycl::queue &queue, int64_t n, T alpha, sycl::buffer int64_t incx, sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); auto y_acc = y.template get_access(cgh); @@ -135,6 +141,7 @@ AXPY_LAUNCHER(float, rocblas_saxpy) AXPY_LAUNCHER(double, rocblas_daxpy) AXPY_LAUNCHER(std::complex, rocblas_caxpy) AXPY_LAUNCHER(std::complex, rocblas_zaxpy) + #undef AXPY_LAUNCHER void axpby(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, @@ -164,6 +171,7 @@ inline void rotg(Func func, sycl::queue &queue, sycl::buffer &a, sycl::bu sycl::buffer &c, sycl::buffer &s) { using rocDataType1 = typename RocEquivalentType::Type; using rocDataType2 = typename RocEquivalentType::Type; + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -201,6 +209,7 @@ ROTG_LAUNCHER(float, float, rocblas_srotg) ROTG_LAUNCHER(double, double, rocblas_drotg) ROTG_LAUNCHER(std::complex, float, rocblas_crotg) ROTG_LAUNCHER(std::complex, double, rocblas_zrotg) + #undef ROTG_LAUNCHER template @@ -208,6 +217,7 @@ inline void rotm(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x sycl::buffer &y, int64_t incy, sycl::buffer ¶m) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); auto y_acc = y.template get_access(cgh); @@ -242,6 +252,7 @@ inline void rotm(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x ROTM_LAUNCHER(float, rocblas_srotm) ROTM_LAUNCHER(double, rocblas_drotm) + #undef ROTM_LAUNCHER template @@ -249,6 +260,7 @@ inline void copy(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); auto y_acc = y.template get_access(cgh); @@ -273,6 +285,7 @@ COPY_LAUNCHER(float, rocblas_scopy) COPY_LAUNCHER(double, rocblas_dcopy) COPY_LAUNCHER(std::complex, rocblas_ccopy) COPY_LAUNCHER(std::complex, rocblas_zcopy) + #undef COPY_LAUNCHER template @@ -280,6 +293,7 @@ inline void dot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, sycl::buffer &y, int64_t incy, sycl::buffer &result) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); auto y_acc = y.template get_access(cgh); @@ -311,14 +325,21 @@ inline void dot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, sycl::buffer &y, const int64_t incy, sycl::buffer &result) { \ dot(ROCBLAS_ROUTINE, queue, n, x, incx, y, incy, result); \ } + DOT_LAUNCHER(, float, rocblas_sdot) DOT_LAUNCHER(, double, rocblas_ddot) -DOT_LAUNCHER(c, std::complex, rocblas_cdotc) -DOT_LAUNCHER(c, std::complex, rocblas_zdotc) DOT_LAUNCHER(u, std::complex, rocblas_cdotu) +DOT_LAUNCHER(c, std::complex, rocblas_cdotc) DOT_LAUNCHER(u, std::complex, rocblas_zdotu) +DOT_LAUNCHER(c, std::complex, rocblas_zdotc) + #undef DOT_LAUNCHER +void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, + sycl::buffer &y, int64_t incy, sycl::buffer &result) { + throw unimplemented("blas", "dot", "for column_major layout"); +} + template inline void rot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &y, int64_t incy, T2 c, T3 s) { @@ -326,6 +347,7 @@ inline void rot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x using rocDataType2 = typename RocEquivalentType::Type; using rocDataType3 = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); auto y_acc = y.template get_access(cgh); @@ -356,11 +378,13 @@ ROT_LAUNCHER(float, float, float, rocblas_srot) ROT_LAUNCHER(double, double, double, rocblas_drot) ROT_LAUNCHER(std::complex, float, float, rocblas_csrot) ROT_LAUNCHER(std::complex, double, double, rocblas_zdrot) + #undef ROT_LAUNCHER void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &result) { overflow_check(n, incx, incy); + // rocBLAS does not support sdot so we need to mimic sdot. queue.submit([&](sycl::handler &cgh) { auto x_acc = x.get_access(cgh); @@ -386,21 +410,18 @@ void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); + // Since SB is a host pointer we need to bring the result back to the host and // add sb to it. result.get_access()[0] += sb; } -void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - sycl::buffer &y, int64_t incy, sycl::buffer &result) { - throw unimplemented("blas", "dot", "for column_major layout"); -} - template inline void rotmg(Func func, sycl::queue &queue, sycl::buffer &d1, sycl::buffer &d2, sycl::buffer &x1, T y1, sycl::buffer ¶m) { using rocDataType = typename RocEquivalentType::Type; sycl::buffer y1_buff(&y1, sycl::range<1>(1)); + queue.submit([&](sycl::handler &cgh) { auto d1_acc = d1.template get_access(cgh); auto d2_acc = d2.template get_access(cgh); @@ -439,6 +460,7 @@ inline void rotmg(Func func, sycl::queue &queue, sycl::buffer &d1, sycl::b ROTMG_LAUNCHER(float, rocblas_srotmg) ROTMG_LAUNCHER(double, rocblas_drotmg) + #undef ROTMG_LAUNCHER template @@ -446,6 +468,7 @@ inline void iamax(Func func, sycl::queue &queue, int64_t n, sycl::buffer & const int64_t incx, sycl::buffer &result) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + // rocBLAS does not support int64_t as return type for the data by default. So we need to // mimic iamax. We are converting the result to be the int and then we convert // it back to the actual data on the host. @@ -478,10 +501,13 @@ inline void iamax(Func func, sycl::queue &queue, int64_t n, sycl::buffer & rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); - // This requires to bring the data to host, copy it, and return it back to - // the device - result.template get_access()[0] = std::max( - (int64_t)int_res_buff.template get_access()[0] - 1, int64_t{ 0 }); + + queue.submit([&](sycl::handler &cgh) { + auto int_res_acc = int_res_buff.template get_access(cgh); + auto result_acc = result.template get_access(cgh); + cgh.single_task( + [=]() { result_acc[0] = std::max((int64_t)int_res_acc[0] - 1, (int64_t)0); }); + }); } #define IAMAX_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -489,10 +515,12 @@ inline void iamax(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ iamax(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + IAMAX_LAUNCHER(float, rocblas_isamax) IAMAX_LAUNCHER(double, rocblas_idamax) IAMAX_LAUNCHER(std::complex, rocblas_icamax) IAMAX_LAUNCHER(std::complex, rocblas_izamax) + #undef IAMAX_LAUNCHER template @@ -500,6 +528,7 @@ inline void swap(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto x_acc = x.template get_access(cgh); auto y_acc = y.template get_access(cgh); @@ -524,6 +553,7 @@ SWAP_LAUNCHER(float, rocblas_sswap) SWAP_LAUNCHER(double, rocblas_dswap) SWAP_LAUNCHER(std::complex, rocblas_cswap) SWAP_LAUNCHER(std::complex, rocblas_zswap) + #undef SWAP_LAUNCHER template @@ -531,6 +561,7 @@ inline void iamin(Func func, sycl::queue &queue, int64_t n, sycl::buffer & const int64_t incx, sycl::buffer &result) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + // rocBLAS does not support int64_t as return type for the data by default. So we need to // mimic iamin we are converting the result to be the int and then we convert // it back to the actual data on the host. @@ -563,8 +594,13 @@ inline void iamin(Func func, sycl::queue &queue, int64_t n, sycl::buffer & rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); - result.template get_access()[0] = std::max( - (int64_t)int_res_buff.template get_access()[0] - 1, int64_t{ 0 }); + + queue.submit([&](sycl::handler &cgh) { + auto int_res_acc = int_res_buff.template get_access(cgh); + auto result_acc = result.template get_access(cgh); + cgh.single_task( + [=]() { result_acc[0] = std::max((int64_t)int_res_acc[0] - 1, (int64_t)0); }); + }); } #define IAMIN_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -572,10 +608,12 @@ inline void iamin(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ iamin(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + IAMIN_LAUNCHER(float, rocblas_isamin) IAMIN_LAUNCHER(double, rocblas_idamin) IAMIN_LAUNCHER(std::complex, rocblas_icamin) IAMIN_LAUNCHER(std::complex, rocblas_izamin) + #undef IAMIN_LAUNCHER template @@ -615,15 +653,16 @@ inline void nrm2(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ nrm2(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + NRM2_LAUNCHER(float, float, rocblas_snrm2) NRM2_LAUNCHER(double, double, rocblas_dnrm2) NRM2_LAUNCHER(std::complex, float, rocblas_scnrm2) NRM2_LAUNCHER(std::complex, double, rocblas_dznrm2) + #undef NRM2_LAUNCHER // USM APIs -// Level 1 template inline sycl::event asum(Func func, sycl::queue &queue, int64_t n, const T1 *x, const int64_t incx, T2 *result, const std::vector &dependencies) { @@ -632,10 +671,7 @@ inline sycl::event asum(Func func, sycl::queue &queue, int64_t n, const T1 *x, c overflow_check(n, incx); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device); @@ -648,6 +684,7 @@ inline sycl::event asum(Func func, sycl::queue &queue, int64_t n, const T1 *x, c rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); + return done; } @@ -656,10 +693,12 @@ inline sycl::event asum(Func func, sycl::queue &queue, int64_t n, const T1 *x, c TYPE2 *result, const std::vector &dependencies) { \ return asum(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + ASUM_LAUNCHER_USM(float, float, rocblas_sasum) ASUM_LAUNCHER_USM(double, double, rocblas_dasum) ASUM_LAUNCHER_USM(std::complex, float, rocblas_scasum) ASUM_LAUNCHER_USM(std::complex, double, rocblas_dzasum) + #undef ASUM_LAUNCHER_USM template @@ -668,11 +707,9 @@ inline sycl::event scal(Func func, sycl::queue &queue, int64_t n, T1 a, T2 *x, i using rocDataType1 = typename RocEquivalentType::Type; using rocDataType2 = typename RocEquivalentType::Type; overflow_check(n, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -682,6 +719,7 @@ inline sycl::event scal(Func func, sycl::queue &queue, int64_t n, T1 a, T2 *x, i ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, (rocDataType1 *)&a, x_, std::abs(incx)); }); }); + return done; } @@ -690,12 +728,14 @@ inline sycl::event scal(Func func, sycl::queue &queue, int64_t n, T1 a, T2 *x, i const std::vector &dependencies) { \ return scal(ROCBLAS_ROUTINE, queue, n, a, x, incx, dependencies); \ } + SCAL_LAUNCHER_USM(float, float, rocblas_sscal) SCAL_LAUNCHER_USM(double, double, rocblas_dscal) SCAL_LAUNCHER_USM(std::complex, std::complex, rocblas_cscal) SCAL_LAUNCHER_USM(std::complex, std::complex, rocblas_zscal) SCAL_LAUNCHER_USM(float, std::complex, rocblas_csscal) SCAL_LAUNCHER_USM(double, std::complex, rocblas_zdscal) + #undef SCAL_LAUNCHER_USM template @@ -703,11 +743,9 @@ inline sycl::event axpy(Func func, sycl::queue &queue, int64_t n, T alpha, const T *y, int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -718,6 +756,7 @@ inline sycl::event axpy(Func func, sycl::queue &queue, int64_t n, T alpha, const incy); }); }); + return done; } @@ -731,6 +770,7 @@ AXPY_LAUNCHER_USM(float, rocblas_saxpy) AXPY_LAUNCHER_USM(double, rocblas_daxpy) AXPY_LAUNCHER_USM(std::complex, rocblas_caxpy) AXPY_LAUNCHER_USM(std::complex, rocblas_zaxpy) + #undef AXPY_LAUNCHER_USM sycl::event axpby(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, @@ -761,11 +801,9 @@ inline sycl::event rotg(Func func, sycl::queue &queue, T1 *a, T1 *b, T2 *c, T1 * const std::vector &dependencies) { using rocDataType1 = typename RocEquivalentType::Type; using rocDataType2 = typename RocEquivalentType::Type; + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -777,6 +815,7 @@ inline sycl::event rotg(Func func, sycl::queue &queue, T1 *a, T1 *b, T2 *c, T1 * ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, a_, b_, c_, s_); }); }); + return done; } @@ -790,6 +829,7 @@ ROTG_LAUNCHER_USM(float, float, rocblas_srotg) ROTG_LAUNCHER_USM(double, double, rocblas_drotg) ROTG_LAUNCHER_USM(std::complex, float, rocblas_crotg) ROTG_LAUNCHER_USM(std::complex, double, rocblas_zrotg) + #undef ROTG_LAUNCHER_USM template @@ -797,11 +837,9 @@ inline sycl::event rotm(Func func, sycl::queue &queue, int64_t n, T *x, int64_t int64_t incy, T *param, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -812,6 +850,7 @@ inline sycl::event rotm(Func func, sycl::queue &queue, int64_t n, T *x, int64_t ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, y_, incy, param_); }); }); + return done; } @@ -823,6 +862,7 @@ inline sycl::event rotm(Func func, sycl::queue &queue, int64_t n, T *x, int64_t ROTM_LAUNCHER_USM(float, rocblas_srotm) ROTM_LAUNCHER_USM(double, rocblas_drotm) + #undef ROTM_LAUNCHER_USM template @@ -830,11 +870,9 @@ inline sycl::event copy(Func func, sycl::queue &queue, int64_t n, const T *x, in int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -844,6 +882,7 @@ inline sycl::event copy(Func func, sycl::queue &queue, int64_t n, const T *x, in ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, y_, incy); }); }); + return done; } @@ -857,6 +896,7 @@ COPY_LAUNCHER_USM(float, rocblas_scopy) COPY_LAUNCHER_USM(double, rocblas_dcopy) COPY_LAUNCHER_USM(std::complex, rocblas_ccopy) COPY_LAUNCHER_USM(std::complex, rocblas_zcopy) + #undef COPY_LAUNCHER_USM template @@ -865,11 +905,9 @@ inline sycl::event dot(Func func, sycl::queue &queue, int64_t n, const T *x, con const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -880,6 +918,7 @@ inline sycl::event dot(Func func, sycl::queue &queue, int64_t n, const T *x, con ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, y_, incy, res_); }); }); + return done; } @@ -889,14 +928,21 @@ inline sycl::event dot(Func func, sycl::queue &queue, int64_t n, const T *x, con const std::vector &dependencies) { \ return dot(ROCBLAS_ROUTINE, queue, n, x, incx, y, incy, result, dependencies); \ } + DOT_LAUNCHER_USM(, float, rocblas_sdot) DOT_LAUNCHER_USM(, double, rocblas_ddot) -DOT_LAUNCHER_USM(c, std::complex, rocblas_cdotc) -DOT_LAUNCHER_USM(c, std::complex, rocblas_zdotc) DOT_LAUNCHER_USM(u, std::complex, rocblas_cdotu) +DOT_LAUNCHER_USM(c, std::complex, rocblas_cdotc) DOT_LAUNCHER_USM(u, std::complex, rocblas_zdotu) +DOT_LAUNCHER_USM(c, std::complex, rocblas_zdotc) + #undef DOT_LAUNCHER_USM +sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, + int64_t incy, double *result, const std::vector &dependencies) { + throw unimplemented("blas", "dot", "for column_major layout"); +} + template inline sycl::event rot(Func func, sycl::queue &queue, int64_t n, T1 *x, const int64_t incx, T1 *y, int64_t incy, T2 c, T3 s, const std::vector &dependencies) { @@ -904,11 +950,9 @@ inline sycl::event rot(Func func, sycl::queue &queue, int64_t n, T1 *x, const in using rocDataType2 = typename RocEquivalentType::Type; using rocDataType3 = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -919,6 +963,7 @@ inline sycl::event rot(Func func, sycl::queue &queue, int64_t n, T1 *x, const in (rocDataType3 *)&s); }); }); + return done; } @@ -933,18 +978,17 @@ ROT_LAUNCHER_USM(float, float, float, rocblas_srot) ROT_LAUNCHER_USM(double, double, double, rocblas_drot) ROT_LAUNCHER_USM(std::complex, float, float, rocblas_csrot) ROT_LAUNCHER_USM(std::complex, double, double, rocblas_zdrot) + #undef ROT_LAUNCHER_USM sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int64_t incx, const float *y, int64_t incy, float *result, const std::vector &dependencies) { overflow_check(n, incx, incy); + // rocBLAS does not support sdot so we need to mimic sdot. auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -955,25 +999,19 @@ sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int6 ROCBLAS_ERROR_FUNC_SYNC(rocblas_sdot, err, handle, n, x_, incx, y_, incy, res_); }); }); - done.wait(); + + done.wait_and_throw(); result[0] = result[0] + sb; return done; } -sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, - int64_t incy, double *result, const std::vector &dependencies) { - throw unimplemented("blas", "dot", "for column_major layout"); -} - template inline sycl::event rotmg(Func func, sycl::queue &queue, T *d1, T *d2, T *x1, T y1, T *param, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -986,6 +1024,7 @@ inline sycl::event rotmg(Func func, sycl::queue &queue, T *d1, T *d2, T *x1, T y ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, d1_, d2_, x1_, y1_, param_); }); }); + return done; } @@ -997,6 +1036,7 @@ inline sycl::event rotmg(Func func, sycl::queue &queue, T *d1, T *d2, T *x1, T y ROTMG_LAUNCHER_USM(float, rocblas_srotmg) ROTMG_LAUNCHER_USM(double, rocblas_drotmg) + #undef ROTMG_LAUNCHER_USM template @@ -1012,11 +1052,9 @@ inline sycl::event iamax(Func func, sycl::queue &queue, int64_t n, const T *x, c auto int_res_p = (int *)sycl::aligned_alloc_shared(64, sizeof(rocblas_int), queue.get_device(), queue.get_context()); *int_res_p = 0; + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device); @@ -1029,7 +1067,8 @@ inline sycl::event iamax(Func func, sycl::queue &queue, int64_t n, const T *x, c rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); - done.wait(); + + done.wait_and_throw(); result[0] = std::max((int64_t)(*int_res_p - 1), int64_t{ 0 }); return done; } @@ -1039,10 +1078,12 @@ inline sycl::event iamax(Func func, sycl::queue &queue, int64_t n, const T *x, c int64_t *result, const std::vector &dependencies) { \ return iamax(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + IAMAX_LAUNCHER_USM(float, rocblas_isamax) IAMAX_LAUNCHER_USM(double, rocblas_idamax) IAMAX_LAUNCHER_USM(std::complex, rocblas_icamax) IAMAX_LAUNCHER_USM(std::complex, rocblas_izamax) + #undef IAMAX_LAUNCHER_USM template @@ -1050,11 +1091,9 @@ inline sycl::event swap(Func func, sycl::queue &queue, int64_t n, T *x, int64_t int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1064,6 +1103,7 @@ inline sycl::event swap(Func func, sycl::queue &queue, int64_t n, T *x, int64_t ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, n, x_, incx, y_, incy); }); }); + return done; } @@ -1077,6 +1117,7 @@ SWAP_LAUNCHER_USM(float, rocblas_sswap) SWAP_LAUNCHER_USM(double, rocblas_dswap) SWAP_LAUNCHER_USM(std::complex, rocblas_cswap) SWAP_LAUNCHER_USM(std::complex, rocblas_zswap) + #undef SWAP_LAUNCHER_USM template @@ -1092,11 +1133,9 @@ inline sycl::event iamin(Func func, sycl::queue &queue, int64_t n, const T *x, c auto int_res_p = (int *)sycl::aligned_alloc_shared(64, sizeof(rocblas_int), queue.get_device(), queue.get_context()); *int_res_p = 0; + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device); @@ -1110,7 +1149,8 @@ inline sycl::event iamin(Func func, sycl::queue &queue, int64_t n, const T *x, c rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); - done.wait(); + + done.wait_and_throw(); result[0] = std::max((int64_t)(*int_res_p - 1), int64_t{ 0 }); return done; } @@ -1120,10 +1160,12 @@ inline sycl::event iamin(Func func, sycl::queue &queue, int64_t n, const T *x, c int64_t *result, const std::vector &dependencies) { \ return iamin(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + IAMIN_LAUNCHER_USM(float, rocblas_isamin) IAMIN_LAUNCHER_USM(double, rocblas_idamin) IAMIN_LAUNCHER_USM(std::complex, rocblas_icamin) IAMIN_LAUNCHER_USM(std::complex, rocblas_izamin) + #undef IAMIN_LAUNCHER_USM template @@ -1134,10 +1176,7 @@ inline sycl::event nrm2(Func func, sycl::queue &queue, int64_t n, const T1 *x, c overflow_check(n, incx); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); rocblas_set_pointer_mode(handle, rocblas_pointer_mode_device); @@ -1150,6 +1189,7 @@ inline sycl::event nrm2(Func func, sycl::queue &queue, int64_t n, const T1 *x, c rocblas_set_pointer_mode(handle, rocblas_pointer_mode_host); }); }); + return done; } @@ -1158,10 +1198,12 @@ inline sycl::event nrm2(Func func, sycl::queue &queue, int64_t n, const T1 *x, c TYPE2 *result, const std::vector &dependencies) { \ return nrm2(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + NRM2_LAUNCHER_USM(float, float, rocblas_snrm2) NRM2_LAUNCHER_USM(double, double, rocblas_dnrm2) NRM2_LAUNCHER_USM(std::complex, float, rocblas_scnrm2) NRM2_LAUNCHER_USM(std::complex, double, rocblas_dznrm2) + #undef NRM2_LAUNCHER_USM } // namespace column_major @@ -1169,11 +1211,10 @@ namespace row_major { // Buffer APIs -// Level 1 template inline void asum(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &result) { - throw unimplemented("blas", "asum", "for row_major layout"); + column_major::asum(func, queue, n, x, incx, result); } #define ASUM_LAUNCHER(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1181,34 +1222,38 @@ inline void asum(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ asum(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + ASUM_LAUNCHER(float, float, rocblas_sasum) ASUM_LAUNCHER(double, double, rocblas_dasum) ASUM_LAUNCHER(std::complex, float, rocblas_scasum) ASUM_LAUNCHER(std::complex, double, rocblas_dzasum) + #undef ASUM_LAUNCHER template inline void scal(Func func, sycl::queue &queue, int64_t n, T1 a, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "scal", "for row_major layout"); + column_major::scal(func, queue, n, a, x, incx); } #define SCAL_LAUNCHER(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ void scal(sycl::queue &queue, int64_t n, TYPE1 a, sycl::buffer &x, int64_t incx) { \ scal(ROCBLAS_ROUTINE, queue, n, a, x, incx); \ } + SCAL_LAUNCHER(float, float, rocblas_sscal) SCAL_LAUNCHER(double, double, rocblas_dscal) SCAL_LAUNCHER(std::complex, std::complex, rocblas_cscal) SCAL_LAUNCHER(std::complex, std::complex, rocblas_zscal) SCAL_LAUNCHER(float, std::complex, rocblas_csscal) SCAL_LAUNCHER(double, std::complex, rocblas_zdscal) + #undef SCAL_LAUNCHER template inline void axpy(Func func, sycl::queue &queue, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "axpy", "for row_major layout"); + column_major::axpy(func, queue, n, alpha, x, incx, y, incy); } #define AXPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1221,6 +1266,7 @@ AXPY_LAUNCHER(float, rocblas_saxpy) AXPY_LAUNCHER(double, rocblas_daxpy) AXPY_LAUNCHER(std::complex, rocblas_caxpy) AXPY_LAUNCHER(std::complex, rocblas_zaxpy) + #undef AXPY_LAUNCHER void axpby(sycl::queue &queue, int64_t n, float alpha, sycl::buffer &x, int64_t incx, @@ -1248,7 +1294,7 @@ void axpby(sycl::queue &queue, int64_t n, std::complex alpha, template inline void rotg(Func func, sycl::queue &queue, sycl::buffer &a, sycl::buffer &b, sycl::buffer &c, sycl::buffer &s) { - throw unimplemented("blas", "rotg", "for row_major layout"); + column_major::rotg(func, queue, a, b, c, s); } #define ROTG_LAUNCHER(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1261,12 +1307,13 @@ ROTG_LAUNCHER(float, float, rocblas_srotg) ROTG_LAUNCHER(double, double, rocblas_drotg) ROTG_LAUNCHER(std::complex, float, rocblas_crotg) ROTG_LAUNCHER(std::complex, double, rocblas_zrotg) + #undef ROTG_LAUNCHER template inline void rotm(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer ¶m) { - throw unimplemented("blas", "rotm", "for row_major layout"); + column_major::rotm(func, queue, n, x, incx, y, incy, param); } #define ROTM_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1277,12 +1324,13 @@ inline void rotm(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x ROTM_LAUNCHER(float, rocblas_srotm) ROTM_LAUNCHER(double, rocblas_drotm) + #undef ROTM_LAUNCHER template inline void copy(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "copy", "for row_major layout"); + column_major::copy(func, queue, n, x, incx, y, incy); } #define COPY_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1295,12 +1343,13 @@ COPY_LAUNCHER(float, rocblas_scopy) COPY_LAUNCHER(double, rocblas_dcopy) COPY_LAUNCHER(std::complex, rocblas_ccopy) COPY_LAUNCHER(std::complex, rocblas_zcopy) + #undef COPY_LAUNCHER template inline void dot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &result) { - throw unimplemented("blas", "dot", "for row_major layout"); + column_major::dot(func, queue, n, x, incx, y, incy, result); } #define DOT_LAUNCHER(EXT, TYPE, ROCBLAS_ROUTINE) \ @@ -1308,18 +1357,25 @@ inline void dot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, sycl::buffer &y, const int64_t incy, sycl::buffer &result) { \ dot(ROCBLAS_ROUTINE, queue, n, x, incx, y, incy, result); \ } + DOT_LAUNCHER(, float, rocblas_sdot) DOT_LAUNCHER(, double, rocblas_ddot) -DOT_LAUNCHER(c, std::complex, rocblas_cdotc) -DOT_LAUNCHER(c, std::complex, rocblas_zdotc) DOT_LAUNCHER(u, std::complex, rocblas_cdotu) +DOT_LAUNCHER(c, std::complex, rocblas_cdotc) DOT_LAUNCHER(u, std::complex, rocblas_zdotu) +DOT_LAUNCHER(c, std::complex, rocblas_zdotc) + #undef DOT_LAUNCHER +void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, + sycl::buffer &y, int64_t incy, sycl::buffer &result) { + throw unimplemented("blas", "dot", "for row_major layout"); +} + template inline void rot(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &y, int64_t incy, T2 c, T3 s) { - throw unimplemented("blas", "rot", "for row_major layout"); + column_major::rot(func, queue, n, x, incx, y, incy, c, s); } #define ROT_LAUNCHER(TYPE1, TYPE2, TYPE3, ROCBLAS_ROUTINE) \ @@ -1332,22 +1388,18 @@ ROT_LAUNCHER(float, float, float, rocblas_srot) ROT_LAUNCHER(double, double, double, rocblas_drot) ROT_LAUNCHER(std::complex, float, float, rocblas_csrot) ROT_LAUNCHER(std::complex, double, double, rocblas_zdrot) + #undef ROT_LAUNCHER void sdsdot(sycl::queue &queue, int64_t n, float sb, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &result) { - throw unimplemented("blas", "sdsdot", "for row_major layout"); -} - -void dot(sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, - sycl::buffer &y, int64_t incy, sycl::buffer &result) { - throw unimplemented("blas", "dot", "for row_major layout"); + column_major::sdsdot(queue, n, sb, x, incx, y, incy, result); } template inline void rotmg(Func func, sycl::queue &queue, sycl::buffer &d1, sycl::buffer &d2, sycl::buffer &x1, T y1, sycl::buffer ¶m) { - throw unimplemented("blas", "rotmg", "for row_major layout"); + column_major::rotmg(func, queue, d1, d2, x1, y1, param); } #define ROTMG_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1358,12 +1410,13 @@ inline void rotmg(Func func, sycl::queue &queue, sycl::buffer &d1, sycl::b ROTMG_LAUNCHER(float, rocblas_srotmg) ROTMG_LAUNCHER(double, rocblas_drotmg) + #undef ROTMG_LAUNCHER template inline void iamax(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &result) { - throw unimplemented("blas", "iamax", "for row_major layout"); + column_major::iamax(func, queue, n, x, incx, result); } #define IAMAX_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1371,16 +1424,18 @@ inline void iamax(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ iamax(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + IAMAX_LAUNCHER(float, rocblas_isamax) IAMAX_LAUNCHER(double, rocblas_idamax) IAMAX_LAUNCHER(std::complex, rocblas_icamax) IAMAX_LAUNCHER(std::complex, rocblas_izamax) + #undef IAMAX_LAUNCHER template inline void swap(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "swap", "for row_major layout"); + column_major::swap(func, queue, n, x, incx, y, incy); } #define SWAP_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1393,12 +1448,13 @@ SWAP_LAUNCHER(float, rocblas_sswap) SWAP_LAUNCHER(double, rocblas_dswap) SWAP_LAUNCHER(std::complex, rocblas_cswap) SWAP_LAUNCHER(std::complex, rocblas_zswap) + #undef SWAP_LAUNCHER template inline void iamin(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &result) { - throw unimplemented("blas", "iamin", "for row_major layout"); + column_major::iamin(func, queue, n, x, incx, result); } #define IAMIN_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1406,16 +1462,18 @@ inline void iamin(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ iamin(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + IAMIN_LAUNCHER(float, rocblas_isamin) IAMIN_LAUNCHER(double, rocblas_idamin) IAMIN_LAUNCHER(std::complex, rocblas_icamin) IAMIN_LAUNCHER(std::complex, rocblas_izamin) + #undef IAMIN_LAUNCHER template inline void nrm2(Func func, sycl::queue &queue, int64_t n, sycl::buffer &x, const int64_t incx, sycl::buffer &result) { - throw unimplemented("blas", "nrm2", "for row_major layout"); + column_major::nrm2(func, queue, n, x, incx, result); } #define NRM2_LAUNCHER(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1423,19 +1481,20 @@ inline void nrm2(Func func, sycl::queue &queue, int64_t n, sycl::buffer & sycl::buffer &result) { \ nrm2(ROCBLAS_ROUTINE, queue, n, x, incx, result); \ } + NRM2_LAUNCHER(float, float, rocblas_snrm2) NRM2_LAUNCHER(double, double, rocblas_dnrm2) NRM2_LAUNCHER(std::complex, float, rocblas_scnrm2) NRM2_LAUNCHER(std::complex, double, rocblas_dznrm2) + #undef NRM2_LAUNCHER // USM APIs -// Level 1 template inline sycl::event asum(Func func, sycl::queue &queue, int64_t n, const T1 *x, const int64_t incx, T2 *result, const std::vector &dependencies) { - throw unimplemented("blas", "asum", "for row_major layout"); + return column_major::asum(func, queue, n, x, incx, result, dependencies); } #define ASUM_LAUNCHER_USM(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1443,16 +1502,18 @@ inline sycl::event asum(Func func, sycl::queue &queue, int64_t n, const T1 *x, c TYPE2 *result, const std::vector &dependencies) { \ return asum(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + ASUM_LAUNCHER_USM(float, float, rocblas_sasum) ASUM_LAUNCHER_USM(double, double, rocblas_dasum) ASUM_LAUNCHER_USM(std::complex, float, rocblas_scasum) ASUM_LAUNCHER_USM(std::complex, double, rocblas_dzasum) + #undef ASUM_LAUNCHER_USM template inline sycl::event scal(Func func, sycl::queue &queue, int64_t n, T1 a, T2 *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "scal", "for row_major layout"); + return column_major::scal(func, queue, n, a, x, incx, dependencies); } #define SCAL_LAUNCHER_USM(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1460,18 +1521,20 @@ inline sycl::event scal(Func func, sycl::queue &queue, int64_t n, T1 a, T2 *x, i const std::vector &dependencies) { \ return scal(ROCBLAS_ROUTINE, queue, n, a, x, incx, dependencies); \ } + SCAL_LAUNCHER_USM(float, float, rocblas_sscal) SCAL_LAUNCHER_USM(double, double, rocblas_dscal) SCAL_LAUNCHER_USM(std::complex, std::complex, rocblas_cscal) SCAL_LAUNCHER_USM(std::complex, std::complex, rocblas_zscal) SCAL_LAUNCHER_USM(float, std::complex, rocblas_csscal) SCAL_LAUNCHER_USM(double, std::complex, rocblas_zdscal) + #undef SCAL_LAUNCHER_USM template inline sycl::event axpy(Func func, sycl::queue &queue, int64_t n, T alpha, const T *x, int64_t incx, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "axpy", "for row_major layout"); + return column_major::axpy(func, queue, n, alpha, x, incx, y, incy, dependencies); } #define AXPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1484,6 +1547,7 @@ AXPY_LAUNCHER_USM(float, rocblas_saxpy) AXPY_LAUNCHER_USM(double, rocblas_daxpy) AXPY_LAUNCHER_USM(std::complex, rocblas_caxpy) AXPY_LAUNCHER_USM(std::complex, rocblas_zaxpy) + #undef AXPY_LAUNCHER_USM sycl::event axpby(sycl::queue &queue, int64_t n, float alpha, const float *x, int64_t incx, @@ -1512,7 +1576,7 @@ sycl::event axpby(sycl::queue &queue, int64_t n, std::complex alpha, template inline sycl::event rotg(Func func, sycl::queue &queue, T1 *a, T1 *b, T2 *c, T1 *s, const std::vector &dependencies) { - throw unimplemented("blas", "rotg", "for row_major layout"); + return column_major::rotg(func, queue, a, b, c, s, dependencies); } #define ROTG_LAUNCHER_USM(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1525,12 +1589,13 @@ ROTG_LAUNCHER_USM(float, float, rocblas_srotg) ROTG_LAUNCHER_USM(double, double, rocblas_drotg) ROTG_LAUNCHER_USM(std::complex, float, rocblas_crotg) ROTG_LAUNCHER_USM(std::complex, double, rocblas_zrotg) + #undef ROTG_LAUNCHER_USM template inline sycl::event rotm(Func func, sycl::queue &queue, int64_t n, T *x, int64_t incx, T *y, int64_t incy, T *param, const std::vector &dependencies) { - throw unimplemented("blas", "rotm", "for row_major layout"); + return column_major::rotm(func, queue, n, x, incx, y, incy, param, dependencies); } #define ROTM_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1541,12 +1606,13 @@ inline sycl::event rotm(Func func, sycl::queue &queue, int64_t n, T *x, int64_t ROTM_LAUNCHER_USM(float, rocblas_srotm) ROTM_LAUNCHER_USM(double, rocblas_drotm) + #undef ROTM_LAUNCHER_USM template inline sycl::event copy(Func func, sycl::queue &queue, int64_t n, const T *x, int64_t incx, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "copy", "for row_major layout"); + return column_major::copy(func, queue, n, x, incx, y, incy, dependencies); } #define COPY_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1559,13 +1625,14 @@ COPY_LAUNCHER_USM(float, rocblas_scopy) COPY_LAUNCHER_USM(double, rocblas_dcopy) COPY_LAUNCHER_USM(std::complex, rocblas_ccopy) COPY_LAUNCHER_USM(std::complex, rocblas_zcopy) + #undef COPY_LAUNCHER_USM template inline sycl::event dot(Func func, sycl::queue &queue, int64_t n, const T *x, const int64_t incx, const T *y, int64_t incy, T *result, const std::vector &dependencies) { - throw unimplemented("blas", "dot", "for row_major layout"); + return column_major::dot(func, queue, n, x, incx, y, incy, result, dependencies); } #define DOT_LAUNCHER_USM(EXT, TYPE, ROCBLAS_ROUTINE) \ @@ -1574,18 +1641,25 @@ inline sycl::event dot(Func func, sycl::queue &queue, int64_t n, const T *x, con const std::vector &dependencies) { \ return dot(ROCBLAS_ROUTINE, queue, n, x, incx, y, incy, result, dependencies); \ } + DOT_LAUNCHER_USM(, float, rocblas_sdot) DOT_LAUNCHER_USM(, double, rocblas_ddot) -DOT_LAUNCHER_USM(c, std::complex, rocblas_cdotc) -DOT_LAUNCHER_USM(c, std::complex, rocblas_zdotc) DOT_LAUNCHER_USM(u, std::complex, rocblas_cdotu) +DOT_LAUNCHER_USM(c, std::complex, rocblas_cdotc) DOT_LAUNCHER_USM(u, std::complex, rocblas_zdotu) +DOT_LAUNCHER_USM(c, std::complex, rocblas_zdotc) + #undef DOT_LAUNCHER_USM +sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, + int64_t incy, double *result, const std::vector &dependencies) { + throw unimplemented("blas", "dot", "for row_major layout"); +} + template inline sycl::event rot(Func func, sycl::queue &queue, int64_t n, T1 *x, const int64_t incx, T1 *y, int64_t incy, T2 c, T3 s, const std::vector &dependencies) { - throw unimplemented("blas", "rot", "for row_major layout"); + return column_major::rot(func, queue, n, x, incx, y, incy, c, s, dependencies); } #define ROT_LAUNCHER_USM(TYPE1, TYPE2, TYPE3, ROCBLAS_ROUTINE) \ @@ -1599,23 +1673,19 @@ ROT_LAUNCHER_USM(float, float, float, rocblas_srot) ROT_LAUNCHER_USM(double, double, double, rocblas_drot) ROT_LAUNCHER_USM(std::complex, float, float, rocblas_csrot) ROT_LAUNCHER_USM(std::complex, double, double, rocblas_zdrot) + #undef ROT_LAUNCHER_USM sycl::event sdsdot(sycl::queue &queue, int64_t n, float sb, const float *x, int64_t incx, const float *y, int64_t incy, float *result, const std::vector &dependencies) { - throw unimplemented("blas", "sdsdot", "for row_major layout"); -} - -sycl::event dot(sycl::queue &queue, int64_t n, const float *x, int64_t incx, const float *y, - int64_t incy, double *result, const std::vector &dependencies) { - throw unimplemented("blas", "dot", "for row_major layout"); + return column_major::sdsdot(queue, n, sb, x, incx, y, incy, result); } template inline sycl::event rotmg(Func func, sycl::queue &queue, T *d1, T *d2, T *x1, T y1, T *param, const std::vector &dependencies) { - throw unimplemented("blas", "rotmg", "for row_major layout"); + return column_major::rotmg(func, queue, d1, d2, x1, y1, param, dependencies); } #define ROTMG_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1626,12 +1696,13 @@ inline sycl::event rotmg(Func func, sycl::queue &queue, T *d1, T *d2, T *x1, T y ROTMG_LAUNCHER_USM(float, rocblas_srotmg) ROTMG_LAUNCHER_USM(double, rocblas_drotmg) + #undef ROTMG_LAUNCHER_USM template inline sycl::event iamax(Func func, sycl::queue &queue, int64_t n, const T *x, const int64_t incx, int64_t *result, const std::vector &dependencies) { - throw unimplemented("blas", "iamax", "for row_major layout"); + return column_major::iamax(func, queue, n, x, incx, result, dependencies); } #define IAMAX_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1639,16 +1710,18 @@ inline sycl::event iamax(Func func, sycl::queue &queue, int64_t n, const T *x, c int64_t *result, const std::vector &dependencies) { \ return iamax(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + IAMAX_LAUNCHER_USM(float, rocblas_isamax) IAMAX_LAUNCHER_USM(double, rocblas_idamax) IAMAX_LAUNCHER_USM(std::complex, rocblas_icamax) IAMAX_LAUNCHER_USM(std::complex, rocblas_izamax) + #undef IAMAX_LAUNCHER_USM template inline sycl::event swap(Func func, sycl::queue &queue, int64_t n, T *x, int64_t incx, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "swap", "for row_major layout"); + return column_major::swap(func, queue, n, x, incx, y, incy, dependencies); } #define SWAP_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1661,12 +1734,13 @@ SWAP_LAUNCHER_USM(float, rocblas_sswap) SWAP_LAUNCHER_USM(double, rocblas_dswap) SWAP_LAUNCHER_USM(std::complex, rocblas_cswap) SWAP_LAUNCHER_USM(std::complex, rocblas_zswap) + #undef SWAP_LAUNCHER_USM template inline sycl::event iamin(Func func, sycl::queue &queue, int64_t n, const T *x, const int64_t incx, int64_t *result, const std::vector &dependencies) { - throw unimplemented("blas", "iamin", "for row_major layout"); + return column_major::iamin(func, queue, n, x, incx, result, dependencies); } #define IAMIN_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1674,16 +1748,18 @@ inline sycl::event iamin(Func func, sycl::queue &queue, int64_t n, const T *x, c int64_t *result, const std::vector &dependencies) { \ return iamin(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + IAMIN_LAUNCHER_USM(float, rocblas_isamin) IAMIN_LAUNCHER_USM(double, rocblas_idamin) IAMIN_LAUNCHER_USM(std::complex, rocblas_icamin) IAMIN_LAUNCHER_USM(std::complex, rocblas_izamin) + #undef IAMIN_LAUNCHER_USM template inline sycl::event nrm2(Func func, sycl::queue &queue, int64_t n, const T1 *x, const int64_t incx, T2 *result, const std::vector &dependencies) { - throw unimplemented("blas", "nrm2", "for row_major layout"); + return column_major::nrm2(func, queue, n, x, incx, result, dependencies); } #define NRM2_LAUNCHER_USM(TYPE1, TYPE2, ROCBLAS_ROUTINE) \ @@ -1691,10 +1767,12 @@ inline sycl::event nrm2(Func func, sycl::queue &queue, int64_t n, const T1 *x, c TYPE2 *result, const std::vector &dependencies) { \ return nrm2(ROCBLAS_ROUTINE, queue, n, x, incx, result, dependencies); \ } + NRM2_LAUNCHER_USM(float, float, rocblas_snrm2) NRM2_LAUNCHER_USM(double, double, rocblas_dnrm2) NRM2_LAUNCHER_USM(std::complex, float, rocblas_scnrm2) NRM2_LAUNCHER_USM(std::complex, double, rocblas_dznrm2) + #undef NRM2_LAUNCHER_USM } // namespace row_major diff --git a/src/blas/backends/rocblas/rocblas_level2.cpp b/src/blas/backends/rocblas/rocblas_level2.cpp index 72c4e6a81..882f7ff1c 100644 --- a/src/blas/backends/rocblas/rocblas_level2.cpp +++ b/src/blas/backends/rocblas/rocblas_level2.cpp @@ -18,12 +18,61 @@ * limitations under the License. * **************************************************************************/ + #include "rocblas_helper.hpp" #include "rocblas_task.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/blas/detail/rocblas/onemkl_blas_rocblas.hpp" +// Helper Functions + +template +static inline void conj_vector(sycl::handler &cgh, sycl::buffer &buf, const int64_t len, + const int64_t inc) { + const auto abs_inc = std::abs(inc); + auto acc = buf.template get_access(cgh); + cgh.parallel_for(sycl::range{ (std::size_t)len }, [=](sycl::id<1> id) { + const auto index = id * abs_inc; + acc[index] = std::conj(acc[index]); + }); +} +template +static inline void conj_vector(sycl::handler &cgh, T *ptr, const int64_t len, const int64_t inc) { + const auto abs_inc = std::abs(inc); + cgh.parallel_for(sycl::range{ (std::size_t)len }, [=](sycl::id<1> id) { + const auto index = id * abs_inc; + ptr[index] = std::conj(ptr[index]); + }); +} + +template +static inline void conj_vector(sycl::handler &cgh, sycl::buffer &buf_a, sycl::buffer &buf_b, + const int64_t len, const int64_t inc_a, const int64_t inc_b) { + const auto abs_inc_a = std::abs(inc_a); + const auto abs_inc_b = std::abs(inc_b); + auto acc_a = buf_a.template get_access(cgh); + auto acc_b = buf_b.template get_access(cgh); + cgh.parallel_for(sycl::range{ (std::size_t)len }, [=](sycl::id<1> id) { + const auto index_a = id * abs_inc_a; + const auto index_b = id * abs_inc_b; + acc_a[index_a] = std::conj(acc_a[index_a]); + acc_b[index_b] = std::conj(acc_b[index_b]); + }); +} +template +static inline void conj_vector(sycl::handler &cgh, T *ptr_a, T *ptr_b, const int64_t len, + const int64_t inc_a, const int64_t inc_b) { + const auto abs_inc_a = std::abs(inc_a); + const auto abs_inc_b = std::abs(inc_b); + cgh.parallel_for(sycl::range{ (std::size_t)len }, [=](sycl::id<1> id) { + const auto index_a = id * abs_inc_a; + const auto index_b = id * abs_inc_b; + ptr_a[index_a] = std::conj(ptr_a[index_a]); + ptr_b[index_b] = std::conj(ptr_b[index_b]); + }); +} + namespace oneapi { namespace mkl { namespace blas { @@ -38,6 +87,7 @@ inline void gemv(Func func, sycl::queue &queue, transpose trans, int64_t m, int6 sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, m, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -67,6 +117,7 @@ GEMV_LAUNCHER(float, rocblas_sgemv) GEMV_LAUNCHER(double, rocblas_dgemv) GEMV_LAUNCHER(std::complex, rocblas_cgemv) GEMV_LAUNCHER(std::complex, rocblas_zgemv) + #undef GEMV_LAUNCHER template @@ -75,6 +126,7 @@ inline void gbmv(Func func, sycl::queue &queue, transpose trans, int64_t m, int6 int64_t incx, T beta, sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, m, lda, kl, ku, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -104,6 +156,7 @@ GBMV_LAUNCHER(float, rocblas_sgbmv) GBMV_LAUNCHER(double, rocblas_dgbmv) GBMV_LAUNCHER(std::complex, rocblas_cgbmv) GBMV_LAUNCHER(std::complex, rocblas_zgbmv) + #undef GBMV_LAUNCHER template @@ -112,6 +165,7 @@ inline void ger(Func func, sycl::queue &queue, int64_t m, int64_t n, T alpha, sy int64_t lda) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, m, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -142,6 +196,7 @@ GER_LAUNCHER(u, std::complex, rocblas_cgeru) GER_LAUNCHER(u, std::complex, rocblas_zgeru) GER_LAUNCHER(c, std::complex, rocblas_cgerc) GER_LAUNCHER(c, std::complex, rocblas_zgerc) + #undef GER_LAUNCHER template @@ -150,6 +205,7 @@ inline void hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -177,6 +233,7 @@ inline void hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int HBMV_LAUNCHER(std::complex, rocblas_chbmv) HBMV_LAUNCHER(std::complex, rocblas_zhbmv) + #undef HBMV_LAUNCHER template @@ -185,6 +242,7 @@ inline void hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -212,6 +270,7 @@ inline void hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a HEMV_LAUNCHER(std::complex, rocblas_chemv) HEMV_LAUNCHER(std::complex, rocblas_zhemv) + #undef HEMV_LAUNCHER template @@ -255,6 +314,7 @@ inline void her2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &a, int64_t lda) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -290,6 +350,7 @@ inline void hpmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -326,6 +387,7 @@ inline void hpr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, Scal using rocScalarType = typename RocEquivalentType::Type; using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -358,6 +420,7 @@ inline void hpr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &a) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -393,6 +456,7 @@ inline void sbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -429,6 +493,7 @@ inline void symv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -464,6 +529,7 @@ inline void syr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T al sycl::buffer &x, int64_t incx, sycl::buffer &a, int64_t lda) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -490,6 +556,7 @@ SYR_LAUNCHER(double, rocblas_dsyr) // Intel does not support the following two SYR_LAUNCHER(std::complex, rocblas_csyr) SYR_LAUNCHER(std::complex, rocblas_zsyr) + #undef SYR_LAUNCHER template @@ -498,6 +565,7 @@ inline void syr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &a, int64_t lda) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -536,6 +604,7 @@ inline void spmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &y, int64_t incy) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -571,6 +640,7 @@ inline void spr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T al sycl::buffer &x, int64_t incx, sycl::buffer &a) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -603,6 +673,7 @@ inline void spr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a sycl::buffer &a) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -638,6 +709,7 @@ inline void tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpose tran int64_t incx) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -674,6 +746,7 @@ inline void tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpose tran int64_t incx) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -709,6 +782,7 @@ inline void tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpose tran int64_t n, sycl::buffer &a, sycl::buffer &x, int64_t incx) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -743,6 +817,7 @@ inline void tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpose tran int64_t n, sycl::buffer &a, sycl::buffer &x, int64_t incx) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -778,6 +853,7 @@ inline void trmv(Func func, sycl::queue &queue, uplo upper_lower, transpose tran int64_t incx) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -813,6 +889,7 @@ inline void trsv(Func func, sycl::queue &queue, uplo upper_lower, transpose tran int64_t incx) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto x_acc = x.template get_access(cgh); @@ -849,12 +926,10 @@ inline sycl::event gemv(Func func, sycl::queue &queue, transpose trans, int64_t T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; - overflow_check(n, m, lda, incx, incy); + overflow_check(m, n, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -867,6 +942,7 @@ inline sycl::event gemv(Func func, sycl::queue &queue, transpose trans, int64_t y_, incy); }); }); + return done; } @@ -882,6 +958,7 @@ GEMV_LAUNCHER_USM(float, rocblas_sgemv) GEMV_LAUNCHER_USM(double, rocblas_dgemv) GEMV_LAUNCHER_USM(std::complex, rocblas_cgemv) GEMV_LAUNCHER_USM(std::complex, rocblas_zgemv) + #undef GEMV_LAUNCHER_USM template @@ -891,11 +968,9 @@ inline sycl::event gbmv(Func func, sycl::queue &queue, transpose trans, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, m, lda, kl, ku, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -908,6 +983,7 @@ inline sycl::event gbmv(Func func, sycl::queue &queue, transpose trans, int64_t y_, incy); }); }); + return done; } @@ -924,6 +1000,7 @@ GBMV_LAUNCHER_USM(float, rocblas_sgbmv) GBMV_LAUNCHER_USM(double, rocblas_dgbmv) GBMV_LAUNCHER_USM(std::complex, rocblas_cgbmv) GBMV_LAUNCHER_USM(std::complex, rocblas_zgbmv) + #undef GBMV_LAUNCHER_USM template @@ -932,11 +1009,9 @@ inline sycl::event ger(Func func, sycl::queue &queue, int64_t m, int64_t n, T al const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, m, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -948,6 +1023,7 @@ inline sycl::event ger(Func func, sycl::queue &queue, int64_t m, int64_t n, T al incy, a_, lda); }); }); + return done; } @@ -964,6 +1040,7 @@ GER_LAUNCHER_USM(u, std::complex, rocblas_cgeru) GER_LAUNCHER_USM(u, std::complex, rocblas_zgeru) GER_LAUNCHER_USM(c, std::complex, rocblas_cgerc) GER_LAUNCHER_USM(c, std::complex, rocblas_zgerc) + #undef GER_LAUNCHER_USM template @@ -972,11 +1049,9 @@ inline sycl::event hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -989,6 +1064,7 @@ inline sycl::event hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t y_, incy); }); }); + return done; } @@ -1002,6 +1078,7 @@ inline sycl::event hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t HBMV_LAUNCHER_USM(std::complex, rocblas_chbmv) HBMV_LAUNCHER_USM(std::complex, rocblas_zhbmv) + #undef HBMV_LAUNCHER_USM template @@ -1010,11 +1087,9 @@ inline sycl::event hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1027,6 +1102,7 @@ inline sycl::event hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t y_, incy); }); }); + return done; } @@ -1040,6 +1116,7 @@ inline sycl::event hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t HEMV_LAUNCHER_USM(std::complex, rocblas_chemv) HEMV_LAUNCHER_USM(std::complex, rocblas_zhemv) + #undef HEMV_LAUNCHER_USM template @@ -1051,10 +1128,7 @@ inline sycl::event her(Func func, sycl::queue &queue, uplo upper_lower, int64_t overflow_check(n, lda, incx); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1065,6 +1139,7 @@ inline sycl::event her(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocScalarType *)&alpha, x_, incx, a_, lda); }); }); + return done; } @@ -1086,11 +1161,9 @@ inline sycl::event her2(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1102,6 +1175,7 @@ inline sycl::event her2(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); }); + return done; } @@ -1124,11 +1198,9 @@ inline sycl::event hpmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1141,6 +1213,7 @@ inline sycl::event hpmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t incy); }); }); + return done; } @@ -1164,11 +1237,9 @@ inline sycl::event hpr(Func func, sycl::queue &queue, uplo upper_lower, int64_t using rocScalarType = typename RocEquivalentType::Type; using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1179,6 +1250,7 @@ inline sycl::event hpr(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocScalarType *)&alpha, x_, incx, a_); }); }); + return done; } @@ -1200,11 +1272,9 @@ inline sycl::event hpr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1216,6 +1286,7 @@ inline sycl::event hpr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocDataType *)&alpha, x_, incx, y_, incy, a_); }); }); + return done; } @@ -1238,11 +1309,9 @@ inline sycl::event sbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1255,6 +1324,7 @@ inline sycl::event sbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t y_, incy); }); }); + return done; } @@ -1277,11 +1347,9 @@ inline sycl::event symv(Func func, sycl::queue &queue, uplo upper_lower, int64_t int64_t incy, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1294,6 +1362,7 @@ inline sycl::event symv(Func func, sycl::queue &queue, uplo upper_lower, int64_t y_, incy); }); }); + return done; } @@ -1316,11 +1385,9 @@ inline sycl::event syr(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1331,6 +1398,7 @@ inline sycl::event syr(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocDataType *)&alpha, x_, incx, a_, lda); }); }); + return done; } @@ -1346,6 +1414,7 @@ SYR_LAUNCHER_USM(double, rocblas_dsyr) // Intel does not support the following two SYR_LAUNCHER_USM(std::complex, rocblas_csyr) SYR_LAUNCHER_USM(std::complex, rocblas_zsyr) + #undef SYR_LAUNCHER_USM template @@ -1354,11 +1423,9 @@ inline sycl::event syr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1370,6 +1437,7 @@ inline sycl::event syr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocDataType *)&alpha, x_, incx, y_, incy, a_, lda); }); }); + return done; } @@ -1395,11 +1463,9 @@ inline sycl::event spmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1412,6 +1478,7 @@ inline sycl::event spmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t incy); }); }); + return done; } @@ -1435,10 +1502,7 @@ inline sycl::event spr(Func func, sycl::queue &queue, uplo upper_lower, int64_t using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1449,6 +1513,7 @@ inline sycl::event spr(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocDataType *)&alpha, x_, incx, a_); }); }); + return done; } @@ -1469,11 +1534,9 @@ inline sycl::event spr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx, incy); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1485,6 +1548,7 @@ inline sycl::event spr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t (rocDataType *)&alpha, x_, incx, y_, incy, a_); }); }); + return done; } @@ -1507,11 +1571,9 @@ inline sycl::event tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpo int64_t incx, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1523,6 +1585,7 @@ inline sycl::event tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpo n, k, a_, lda, x_, incx); }); }); + return done; } @@ -1547,11 +1610,9 @@ inline sycl::event tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpo int64_t incx, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1563,6 +1624,7 @@ inline sycl::event tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpo n, k, a_, lda, x_, incx); }); }); + return done; } @@ -1587,11 +1649,9 @@ inline sycl::event tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpo const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1603,6 +1663,7 @@ inline sycl::event tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpo n, a_, x_, incx); }); }); + return done; } @@ -1627,11 +1688,9 @@ inline sycl::event tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpo const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1643,6 +1702,7 @@ inline sycl::event tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpo n, a_, x_, incx); }); }); + return done; } @@ -1667,11 +1727,9 @@ inline sycl::event trmv(Func func, sycl::queue &queue, uplo upper_lower, transpo const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1683,6 +1741,7 @@ inline sycl::event trmv(Func func, sycl::queue &queue, uplo upper_lower, transpo n, a_, lda, x_, incx); }); }); + return done; } @@ -1707,11 +1766,9 @@ inline sycl::event trsv(Func func, sycl::queue &queue, uplo upper_lower, transpo const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, lda, incx); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -1723,6 +1780,7 @@ inline sycl::event trsv(Func func, sycl::queue &queue, uplo upper_lower, transpo n, a_, lda, x_, incx); }); }); + return done; } @@ -1742,15 +1800,49 @@ TRSV_LAUNCHER_USM(std::complex, rocblas_ztrsv) #undef TRSV_LAUNCHER_USM } // namespace column_major + namespace row_major { // Buffer APIs +template +inline void gemv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &x, int64_t incx, std::complex beta, + sycl::buffer, 1> &y, int64_t incy) { + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + alpha = std::conj(alpha); + beta = std::conj(beta); + + if (m > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, m, incx); }); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + } + } + + column_major::gemv(func, queue, new_trans, n, m, alpha, a, lda, x, incx, beta, y, incy); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + } +} + template inline void gemv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "gemv", "for row_major layout"); + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::gemv(func, queue, new_trans, n, m, alpha, a, lda, x, incx, beta, y, incy); } #define GEMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1764,13 +1856,47 @@ GEMV_LAUNCHER(float, rocblas_sgemv) GEMV_LAUNCHER(double, rocblas_dgemv) GEMV_LAUNCHER(std::complex, rocblas_cgemv) GEMV_LAUNCHER(std::complex, rocblas_zgemv) + #undef GEMV_LAUNCHER +template +inline void gbmv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, + int64_t ku, std::complex alpha, sycl::buffer, 1> &a, + int64_t lda, sycl::buffer, 1> &x, int64_t incx, + std::complex beta, sycl::buffer, 1> &y, int64_t incy) { + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + alpha = std::conj(alpha); + beta = std::conj(beta); + + if (m > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, m, incx); }); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + } + } + + column_major::gbmv(func, queue, new_trans, n, m, ku, kl, alpha, a, lda, x, incx, beta, y, incy); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + } +} + template inline void gbmv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, int64_t ku, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "gbmv", "for row_major layout"); + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::gbmv(func, queue, new_trans, n, m, ku, kl, alpha, a, lda, x, incx, beta, y, incy); } #define GBMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1784,35 +1910,70 @@ GBMV_LAUNCHER(float, rocblas_sgbmv) GBMV_LAUNCHER(double, rocblas_dgbmv) GBMV_LAUNCHER(std::complex, rocblas_cgbmv) GBMV_LAUNCHER(std::complex, rocblas_zgbmv) + #undef GBMV_LAUNCHER +template +inline void gerc(Func func, sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, + sycl::buffer, 1> &a, int64_t lda) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + + column_major::ger(func, queue, n, m, alpha, y, incy, x, incx, a, lda); +} + +template +inline void geru(Func func, sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, + sycl::buffer, 1> &x, int64_t incx, + sycl::buffer, 1> &y, int64_t incy, + sycl::buffer, 1> &a, int64_t lda) { + column_major::ger(func, queue, n, m, alpha, y, incy, x, incx, a, lda); +} + template inline void ger(Func func, sycl::queue &queue, int64_t m, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a, int64_t lda) { - throw unimplemented("blas", "ger", "for row_major layout"); + column_major::ger(func, queue, n, m, alpha, y, incy, x, incx, a, lda); } #define GER_LAUNCHER(EXT, TYPE, ROCBLAS_ROUTINE) \ void ger##EXT(sycl::queue &queue, int64_t m, int64_t n, TYPE alpha, sycl::buffer &x, \ int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a, \ int64_t lda) { \ - ger(ROCBLAS_ROUTINE, queue, m, n, alpha, x, incx, y, incy, a, lda); \ + ger##EXT(ROCBLAS_ROUTINE, queue, m, n, alpha, x, incx, y, incy, a, lda); \ } GER_LAUNCHER(, float, rocblas_sger) GER_LAUNCHER(, double, rocblas_dger) GER_LAUNCHER(u, std::complex, rocblas_cgeru) GER_LAUNCHER(u, std::complex, rocblas_zgeru) -GER_LAUNCHER(c, std::complex, rocblas_cgerc) -GER_LAUNCHER(c, std::complex, rocblas_zgerc) +GER_LAUNCHER(c, std::complex, rocblas_cgeru) +GER_LAUNCHER(c, std::complex, rocblas_zgeru) + #undef GER_LAUNCHER template inline void hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "hbmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_alpha = std::conj(alpha); + auto new_beta = std::conj(beta); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, y, n, incx, incy); }); + } + + column_major::hbmv(func, queue, new_uplo, n, k, new_alpha, a, lda, x, incx, new_beta, y, incy); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } } #define HBMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1824,13 +1985,27 @@ inline void hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int HBMV_LAUNCHER(std::complex, rocblas_chbmv) HBMV_LAUNCHER(std::complex, rocblas_zhbmv) + #undef HBMV_LAUNCHER template inline void hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "hemv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_alpha = std::conj(alpha); + auto new_beta = std::conj(beta); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, y, n, incx, incy); }); + } + + column_major::hemv(func, queue, new_uplo, n, new_alpha, a, lda, x, incx, new_beta, y, incy); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } } #define HEMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1842,13 +2017,21 @@ inline void hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T a HEMV_LAUNCHER(std::complex, rocblas_chemv) HEMV_LAUNCHER(std::complex, rocblas_zhemv) + #undef HEMV_LAUNCHER template inline void her(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, ScalarType alpha, sycl::buffer &x, int64_t incx, sycl::buffer &a, int64_t lda) { - throw unimplemented("blas", "her", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + + column_major::her(func, queue, new_uplo, n, alpha, x, incx, a, lda); } #define HER_LAUNCHER(SCALAR_TYPE, DATA_TYPE, ROCBLAS_ROUTINE) \ @@ -1867,7 +2050,14 @@ template inline void her2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a, int64_t lda) { - throw unimplemented("blas", "her2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, y, n, incx, incy); }); + } + + column_major::her2(func, queue, new_uplo, n, alpha, y, incy, x, incx, a, lda); } #define HER2_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1886,7 +2076,20 @@ template inline void hpmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &a, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "hpmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_alpha = std::conj(alpha); + auto new_beta = std::conj(beta); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, y, n, incx, incy); }); + } + + column_major::hpmv(func, queue, new_uplo, n, new_alpha, a, x, incx, new_beta, y, incy); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } } #define HPMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1904,7 +2107,14 @@ HPMV_LAUNCHER(std::complex, rocblas_zhpmv) template inline void hpr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, ScalarType alpha, sycl::buffer &x, int64_t incx, sycl::buffer &a) { - throw unimplemented("blas", "hpr", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + + column_major::hpr(func, queue, new_uplo, n, alpha, x, incx, a); } #define HPR_LAUNCHER(SCALAR_TYPE, DATA_TYPE, ROCBLAS_ROUTINE) \ @@ -1922,7 +2132,14 @@ template inline void hpr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a) { - throw unimplemented("blas", "hpr2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, y, n, incx, incy); }); + } + + column_major::hpr2(func, queue, new_uplo, n, alpha, y, incy, x, incx, a); } #define HPR2_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1941,7 +2158,10 @@ template inline void sbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "sbmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::sbmv(func, queue, new_uplo, n, k, alpha, a, lda, x, incx, beta, y, incy); } #define SBMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1960,7 +2180,10 @@ template inline void symv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "symv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::symv(func, queue, new_uplo, n, alpha, a, lda, x, incx, beta, y, incy); } #define SYMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1978,7 +2201,10 @@ SYMV_LAUNCHER(double, rocblas_dsymv) template inline void syr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &a, int64_t lda) { - throw unimplemented("blas", "syr", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::syr(func, queue, new_uplo, n, alpha, x, incx, a, lda); } #define SYR_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1992,13 +2218,17 @@ SYR_LAUNCHER(double, rocblas_dsyr) // Intel does not support the following two SYR_LAUNCHER(std::complex, rocblas_csyr) SYR_LAUNCHER(std::complex, rocblas_zsyr) + #undef SYR_LAUNCHER template inline void syr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a, int64_t lda) { - throw unimplemented("blas", "syr2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::syr2(func, queue, new_uplo, n, alpha, x, incx, y, incy, a, lda); } #define SYR2_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2020,7 +2250,10 @@ template inline void spmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &a, sycl::buffer &x, int64_t incx, T beta, sycl::buffer &y, int64_t incy) { - throw unimplemented("blas", "spmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::spmv(func, queue, new_uplo, n, alpha, a, x, incx, beta, y, incy); } #define SPMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2038,7 +2271,10 @@ SPMV_LAUNCHER(double, rocblas_dspmv) template inline void spr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &a) { - throw unimplemented("blas", "spr", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::spr(func, queue, new_uplo, n, alpha, x, incx, a); } #define SPR_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2056,7 +2292,10 @@ template inline void spr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, sycl::buffer &x, int64_t incx, sycl::buffer &y, int64_t incy, sycl::buffer &a) { - throw unimplemented("blas", "spr2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::spr2(func, queue, new_uplo, n, alpha, x, incx, y, incy, a); } #define SPR2_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2071,11 +2310,40 @@ SPR2_LAUNCHER(double, rocblas_dspr2) #undef SPR2_LAUNCHER +template +inline void tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + int64_t n, int64_t k, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &x, int64_t incx) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } + + column_major::tbmv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } +} + template inline void tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, int64_t k, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "tbmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::tbmv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx); } #define TBMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2092,11 +2360,40 @@ TBMV_LAUNCHER(std::complex, rocblas_ztbmv) #undef TBMV_LAUNCHER +template +inline void tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + int64_t n, int64_t k, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &x, int64_t incx) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } + + column_major::tbsv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } +} + template inline void tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, int64_t k, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "tbsv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::tbsv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx); } #define TBSV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2113,10 +2410,39 @@ TBSV_LAUNCHER(std::complex, rocblas_ztbsv) #undef TBSV_LAUNCHER +template +inline void tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, int64_t incx) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } + + column_major::tpmv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } +} + template inline void tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, sycl::buffer &a, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "tpmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::tpmv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx); } #define TPMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2132,10 +2458,39 @@ TPMV_LAUNCHER(std::complex, rocblas_ztpmv) #undef TPMV_LAUNCHER +template +inline void tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + int64_t n, sycl::buffer, 1> &a, + sycl::buffer, 1> &x, int64_t incx) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } + + column_major::tpsv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } +} + template inline void tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, sycl::buffer &a, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "tpsv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::tpsv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx); } #define TPSV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2151,11 +2506,40 @@ TPSV_LAUNCHER(std::complex, rocblas_ztpsv) #undef TPSV_LAUNCHER +template +inline void trmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + int64_t n, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &x, int64_t incx) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } + + column_major::trmv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } +} + template inline void trmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "trmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::trmv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx); } #define TRMV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2171,11 +2555,40 @@ TRMV_LAUNCHER(std::complex, rocblas_ztrmv) #undef TRMV_LAUNCHER +template +inline void trsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, + int64_t n, sycl::buffer, 1> &a, int64_t lda, + sycl::buffer, 1> &x, int64_t incx) { + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } + + column_major::trsv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }); + } + } +} + template inline void trsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, sycl::buffer &a, int64_t lda, sycl::buffer &x, int64_t incx) { - throw unimplemented("blas", "trsv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::trsv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx); } #define TRSV_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -2193,11 +2606,57 @@ TRSV_LAUNCHER(std::complex, rocblas_ztrsv) // USM APIs +template +inline sycl::event gemv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + std::complex alpha, const std::complex *a, int64_t lda, + const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { + sycl::event done; + + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + alpha = std::conj(alpha); + beta = std::conj(beta); + + if (m > 0) { + done = queue.submit( + [&](sycl::handler &cgh) { conj_vector(cgh, (std::complex *)x, m, incx); }); + + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + } + } + + done.wait_and_throw(); + + done = column_major::gemv(func, queue, new_trans, n, m, alpha, a, lda, x, incx, beta, y, incy, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, y, n, incy); + }); + } + } + + return done; +} + template inline sycl::event gemv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "gemv", "for row_major layout"); + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::gemv(func, queue, new_trans, n, m, alpha, a, lda, x, incx, beta, y, incy, + dependencies); } #define GEMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2212,14 +2671,61 @@ GEMV_LAUNCHER_USM(float, rocblas_sgemv) GEMV_LAUNCHER_USM(double, rocblas_dgemv) GEMV_LAUNCHER_USM(std::complex, rocblas_cgemv) GEMV_LAUNCHER_USM(std::complex, rocblas_zgemv) + #undef GEMV_LAUNCHER_USM +template +inline sycl::event gbmv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, + int64_t kl, int64_t ku, std::complex alpha, const std::complex *a, + int64_t lda, const std::complex *x, int64_t incx, std::complex beta, + std::complex *y, int64_t incy, + const std::vector &dependencies) { + sycl::event done; + + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + alpha = std::conj(alpha); + beta = std::conj(beta); + + if (m > 0) { + done = queue.submit( + [&](sycl::handler &cgh) { conj_vector(cgh, (std::complex *)x, m, incx); }); + + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, y, n, incy); }); + } + } + } + + done.wait_and_throw(); + + done = column_major::gbmv(func, queue, new_trans, n, m, ku, kl, alpha, a, lda, x, incx, beta, y, + incy, dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, y, n, incy); + }); + } + } + + return done; +} + template inline sycl::event gbmv(Func func, sycl::queue &queue, transpose trans, int64_t m, int64_t n, int64_t kl, int64_t ku, T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "gbmv", "for row_major layout"); + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::gbmv(func, queue, new_trans, n, m, ku, kl, alpha, a, lda, x, incx, beta, y, + incy, dependencies); } #define GBMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2235,35 +2741,81 @@ GBMV_LAUNCHER_USM(float, rocblas_sgbmv) GBMV_LAUNCHER_USM(double, rocblas_dgbmv) GBMV_LAUNCHER_USM(std::complex, rocblas_cgbmv) GBMV_LAUNCHER_USM(std::complex, rocblas_zgbmv) + #undef GBMV_LAUNCHER_USM +template +inline sycl::event gerc(Func func, sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (std::complex *)y, n, incy); }) + .wait_and_throw(); + } + + return column_major::ger(func, queue, n, m, alpha, y, incy, x, incx, a, lda, dependencies); +} + +template +inline sycl::event geru(Func func, sycl::queue &queue, int64_t m, int64_t n, std::complex alpha, + const std::complex *x, int64_t incx, const std::complex *y, + int64_t incy, std::complex *a, int64_t lda, + const std::vector &dependencies) { + return column_major::ger(func, queue, n, m, alpha, y, incy, x, incx, a, lda, dependencies); +} + template inline sycl::event ger(Func func, sycl::queue &queue, int64_t m, int64_t n, T alpha, const T *x, int64_t incx, const T *y, int64_t incy, T *a, int64_t lda, const std::vector &dependencies) { - throw unimplemented("blas", "ger", "for row_major layout"); + return column_major::ger(func, queue, n, m, alpha, y, incy, x, incx, a, lda, dependencies); } -#define GER_LAUNCHER_USM(EXT, TYPE, ROCBLAS_ROUTINE) \ - sycl::event ger##EXT(sycl::queue &queue, int64_t m, int64_t n, TYPE alpha, const TYPE *x, \ - int64_t incx, const TYPE *y, int64_t incy, TYPE *a, int64_t lda, \ - const std::vector &dependencies) { \ - return ger(ROCBLAS_ROUTINE, queue, m, n, alpha, x, incx, y, incy, a, lda, dependencies); \ +#define GER_LAUNCHER_USM(EXT, TYPE, ROCBLAS_ROUTINE) \ + sycl::event ger##EXT(sycl::queue &queue, int64_t m, int64_t n, TYPE alpha, const TYPE *x, \ + int64_t incx, const TYPE *y, int64_t incy, TYPE *a, int64_t lda, \ + const std::vector &dependencies) { \ + return ger##EXT(ROCBLAS_ROUTINE, queue, m, n, alpha, x, incx, y, incy, a, lda, \ + dependencies); \ } GER_LAUNCHER_USM(, float, rocblas_sger) GER_LAUNCHER_USM(, double, rocblas_dger) GER_LAUNCHER_USM(u, std::complex, rocblas_cgeru) GER_LAUNCHER_USM(u, std::complex, rocblas_zgeru) -GER_LAUNCHER_USM(c, std::complex, rocblas_cgerc) -GER_LAUNCHER_USM(c, std::complex, rocblas_zgerc) +GER_LAUNCHER_USM(c, std::complex, rocblas_cgeru) +GER_LAUNCHER_USM(c, std::complex, rocblas_zgeru) + #undef GER_LAUNCHER_USM template inline sycl::event hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "hbmv", "for row_major layout"); + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_alpha = std::conj(alpha); + auto new_beta = std::conj(beta); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (T *)x, y, n, incx, incy); }) + .wait_and_throw(); + } + + done = column_major::hbmv(func, queue, new_uplo, n, k, new_alpha, a, lda, x, incx, new_beta, y, + incy, dependencies); + + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, y, n, incy); + }); + } + + return done; } #define HBMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2276,13 +2828,36 @@ inline sycl::event hbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t HBMV_LAUNCHER_USM(std::complex, rocblas_chbmv) HBMV_LAUNCHER_USM(std::complex, rocblas_zhbmv) + #undef HBMV_LAUNCHER_USM template inline sycl::event hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "hemv", "for row_major layout"); + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_alpha = std::conj(alpha); + auto new_beta = std::conj(beta); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (T *)x, y, n, incx, incy); }) + .wait_and_throw(); + } + + done = column_major::hemv(func, queue, new_uplo, n, new_alpha, a, lda, x, incx, new_beta, y, + incy, dependencies); + + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, y, n, incy); + }); + } + + return done; } #define HEMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2295,13 +2870,22 @@ inline sycl::event hemv(Func func, sycl::queue &queue, uplo upper_lower, int64_t HEMV_LAUNCHER_USM(std::complex, rocblas_chemv) HEMV_LAUNCHER_USM(std::complex, rocblas_zhemv) + #undef HEMV_LAUNCHER_USM template inline sycl::event her(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, const ScalarType alpha, const DataType *x, int64_t incx, DataType *a, int64_t lda, const std::vector &dependencies) { - throw unimplemented("blas", "her", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (DataType *)x, n, incx); }) + .wait_and_throw(); + } + + return column_major::her(func, queue, new_uplo, n, alpha, x, incx, a, lda, dependencies); } #define HER_LAUNCHER_USM(SCALAR_TYPE, DATA_TYPE, ROCBLAS_ROUTINE) \ @@ -2320,7 +2904,16 @@ template inline sycl::event her2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *x, int64_t incx, const T *y, int64_t incy, T *a, int64_t lda, const std::vector &dependencies) { - throw unimplemented("blas", "her2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (T *)x, (T *)y, n, incx, incy); }) + .wait_and_throw(); + } + + return column_major::her2(func, queue, new_uplo, n, alpha, y, incy, x, incx, a, lda, + dependencies); } #define HER2_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2340,7 +2933,29 @@ template inline sycl::event hpmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *a, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "hpmv", "for row_major layout"); + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_alpha = std::conj(alpha); + auto new_beta = std::conj(beta); + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (T *)x, y, n, incx, incy); }) + .wait_and_throw(); + } + + done = column_major::hpmv(func, queue, new_uplo, n, new_alpha, a, x, incx, new_beta, y, incy, + dependencies); + + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, y, n, incy); + }); + } + + return done; } #define HPMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2360,7 +2975,15 @@ template inline sycl::event hpr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, const ScalarType alpha, const DataType *x, int64_t incx, DataType *a, const std::vector &dependencies) { - throw unimplemented("blas", "hpr", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (DataType *)x, n, incx); }) + .wait_and_throw(); + } + + return column_major::hpr(func, queue, new_uplo, n, alpha, x, incx, a, dependencies); } #define HPR_LAUNCHER_USM(SCALAR_TYPE, DATA_TYPE, ROCBLAS_ROUTINE) \ @@ -2379,7 +3002,15 @@ template inline sycl::event hpr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *x, int64_t incx, const T *y, int64_t incy, T *a, const std::vector &dependencies) { - throw unimplemented("blas", "hpr2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, (T *)x, (T *)y, n, incx, incy); }) + .wait_and_throw(); + } + + return column_major::hpr2(func, queue, new_uplo, n, alpha, y, incy, x, incx, a, dependencies); } #define HPR2_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2399,7 +3030,11 @@ template inline sycl::event sbmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "sbmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::sbmv(func, queue, new_uplo, n, k, alpha, a, lda, x, incx, beta, y, incy, + dependencies); } #define SBMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2419,7 +3054,11 @@ template inline sycl::event symv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *a, int64_t lda, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "symv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::symv(func, queue, new_uplo, n, alpha, a, lda, x, incx, beta, y, incy, + dependencies); } #define SYMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2439,7 +3078,10 @@ template inline sycl::event syr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *x, int64_t incx, T *a, int64_t lda, const std::vector &dependencies) { - throw unimplemented("blas", "syr", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::syr(func, queue, new_uplo, n, alpha, x, incx, a, lda, dependencies); } #define SYR_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2454,13 +3096,18 @@ SYR_LAUNCHER_USM(double, rocblas_dsyr) // Intel does not support the following two SYR_LAUNCHER_USM(std::complex, rocblas_csyr) SYR_LAUNCHER_USM(std::complex, rocblas_zsyr) + #undef SYR_LAUNCHER_USM template inline sycl::event syr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *x, int64_t incx, const T *y, int64_t incy, T *a, int64_t lda, const std::vector &dependencies) { - throw unimplemented("blas", "syr2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::syr2(func, queue, new_uplo, n, alpha, x, incx, y, incy, a, lda, + dependencies); } #define SYR2_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2483,7 +3130,11 @@ template inline sycl::event spmv(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *a, const T *x, int64_t incx, T beta, T *y, int64_t incy, const std::vector &dependencies) { - throw unimplemented("blas", "spmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::spmv(func, queue, new_uplo, n, alpha, a, x, incx, beta, y, incy, + dependencies); } #define SPMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2503,7 +3154,10 @@ template inline sycl::event spr(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *x, int64_t incx, T *a, const std::vector &dependencies) { - throw unimplemented("blas", "spr", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::spr(func, queue, new_uplo, n, alpha, x, incx, a, dependencies); } #define SPR_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2521,7 +3175,10 @@ template inline sycl::event spr2(Func func, sycl::queue &queue, uplo upper_lower, int64_t n, T alpha, const T *x, int64_t incx, const T *y, int64_t incy, T *a, const std::vector &dependencies) { - throw unimplemented("blas", "spr2", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::spr2(func, queue, new_uplo, n, alpha, x, incx, y, incy, a, dependencies); } #define SPR2_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2537,11 +3194,51 @@ SPR2_LAUNCHER_USM(double, rocblas_dspr2) #undef SPR2_LAUNCHER_USM +template +inline sycl::event tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, int64_t k, const std::complex *a, int64_t lda, + std::complex *x, int64_t incx, + const std::vector &dependencies) { + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }) + .wait_and_throw(); + } + } + + done = column_major::tbmv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, x, n, incx); + }); + } + } + + return done; +} + template inline sycl::event tbmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, int64_t k, const T *a, int64_t lda, T *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "tbmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::tbmv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx, + dependencies); } #define TBMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2559,11 +3256,51 @@ TBMV_LAUNCHER_USM(std::complex, rocblas_ztbmv) #undef TBMV_LAUNCHER_USM +template +inline sycl::event tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, int64_t k, const std::complex *a, int64_t lda, + std::complex *x, int64_t incx, + const std::vector &dependencies) { + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }) + .wait_and_throw(); + } + } + + done = column_major::tbsv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, x, n, incx); + }); + } + } + + return done; +} + template inline sycl::event tbsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, int64_t k, const T *a, int64_t lda, T *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "tbsv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::tbsv(func, queue, new_uplo, new_trans, unit_diag, n, k, a, lda, x, incx, + dependencies); } #define TBSV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2581,11 +3318,52 @@ TBSV_LAUNCHER_USM(std::complex, rocblas_ztbsv) #undef TBSV_LAUNCHER_USM +template +inline sycl::event tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, const std::complex *a, std::complex *x, + int64_t incx, const std::vector &dependencies) { + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }) + .wait_and_throw(); + } + } + + done = column_major::tpmv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + incx = std::abs(incx); + + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, x, n, incx); + }); + } + } + + return done; +} + template inline sycl::event tpmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, const T *a, T *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "tpmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::tpmv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx, + dependencies); } #define TPMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2603,11 +3381,52 @@ TPMV_LAUNCHER_USM(std::complex, rocblas_ztpmv) #undef TPMV_LAUNCHER_USM +template +inline sycl::event tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, const std::complex *a, std::complex *x, + int64_t incx, const std::vector &dependencies) { + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }) + .wait_and_throw(); + } + } + + done = column_major::tpsv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + incx = std::abs(incx); + + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, x, n, incx); + }); + } + } + + return done; +} + template inline sycl::event tpsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, const T *a, T *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "tpsv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::tpsv(func, queue, new_uplo, new_trans, unit_diag, n, a, x, incx, + dependencies); } #define TPSV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2625,11 +3444,51 @@ TPSV_LAUNCHER_USM(std::complex, rocblas_ztpsv) #undef TPSV_LAUNCHER_USM +template +inline sycl::event trmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, const std::complex *a, int64_t lda, + std::complex *x, int64_t incx, + const std::vector &dependencies) { + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }) + .wait_and_throw(); + } + } + + done = column_major::trmv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, x, n, incx); + }); + } + } + + return done; +} + template inline sycl::event trmv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, const T *a, int64_t lda, T *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "trmv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::trmv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx, + dependencies); } #define TRMV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -2647,11 +3506,51 @@ TRMV_LAUNCHER_USM(std::complex, rocblas_ztrmv) #undef TRMV_LAUNCHER_USM +template +inline sycl::event trsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, + diag unit_diag, int64_t n, const std::complex *a, int64_t lda, + std::complex *x, int64_t incx, + const std::vector &dependencies) { + sycl::event done; + + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + queue.submit([&](sycl::handler &cgh) { conj_vector(cgh, x, n, incx); }) + .wait_and_throw(); + } + } + + done = column_major::trsv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx, + dependencies); + + if (trans == oneapi::mkl::transpose::conjtrans) { + if (n > 0) { + done = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + conj_vector(cgh, x, n, incx); + }); + } + } + + return done; +} + template inline sycl::event trsv(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, diag unit_diag, int64_t n, const T *a, int64_t lda, T *x, int64_t incx, const std::vector &dependencies) { - throw unimplemented("blas", "trsv", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::trsv(func, queue, new_uplo, new_trans, unit_diag, n, a, lda, x, incx, + dependencies); } #define TRSV_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ diff --git a/src/blas/backends/rocblas/rocblas_level3.cpp b/src/blas/backends/rocblas/rocblas_level3.cpp index 15034e8f4..ef739a88b 100644 --- a/src/blas/backends/rocblas/rocblas_level3.cpp +++ b/src/blas/backends/rocblas/rocblas_level3.cpp @@ -18,6 +18,7 @@ * limitations under the License. * **************************************************************************/ + #include "rocblas_helper.hpp" #include "rocblas_task.hpp" @@ -38,12 +39,14 @@ inline void gemm(Func func, sycl::queue &queue, transpose transa, transpose tran sycl::buffer &b, int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); auto c_acc = c.template get_access(cgh); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); + auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); auto c_ = sc.get_mem(c_acc); @@ -71,22 +74,19 @@ GEMM_LAUNCHER(std::complex, rocblas_zgemm) #undef GEMM_LAUNCHER -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, - int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "gemm", "for column_major layout"); -} - -template -inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, T_C alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, - T_C beta, sycl::buffer &c, int64_t ldc) { +template +inline void gemm_ex(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, COMPUTETYPE CT, + sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, T_S alpha, sycl::buffer &a, int64_t lda, + sycl::buffer &b, int64_t ldb, T_S beta, sycl::buffer &c, + int64_t ldc) { using rocDataType_A = typename RocEquivalentType::Type; using rocDataType_B = typename RocEquivalentType::Type; using rocDataType_C = typename RocEquivalentType::Type; + using rocDataType_S = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -99,25 +99,33 @@ inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, s auto c_ = sc.get_mem(c_acc); rocblas_status err; ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType_C *)&alpha, - a_, DT_A, lda, b_, DT_B, ldb, (rocDataType_C *)&beta, c_, DT_C, - ldc, DT_C, rocblas_gemm_algo_standard); + get_rocblas_operation(transb), m, n, k, (rocDataType_S *)&alpha, + a_, DT_A, lda, b_, DT_B, ldb, (rocDataType_S *)&beta, c_, DT_C, + ldc, c_, DT_C, ldc, CT, rocblas_gemm_algo_standard, 0, 0); }); }); } -#define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, \ - ROCMDATATYPE_C) \ +#define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S, ROCBLAS_ROUTINE, ROCMDATATYPE_A, \ + ROCMDATATYPE_B, ROCMDATATYPE_C, ROCMCOMPUTETYPE) \ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_C alpha, sycl::buffer &a, int64_t lda, \ - sycl::buffer &b, int64_t ldb, TYPE_C beta, sycl::buffer &c, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + sycl::buffer &b, int64_t ldb, TYPE_S beta, sycl::buffer &c, \ int64_t ldc) { \ - throw unimplemented("blas", "gemm", "half is disabled"); \ + gemm_ex(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, ROCMCOMPUTETYPE, \ + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -GEMM_EX_LAUNCHER(sycl::half, sycl::half, float, rocblas_gemm_ex, HIP_R_16F, HIP_R_16F, HIP_R_32F) -GEMM_EX_LAUNCHER(sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, HIP_R_16F, HIP_R_16F, - HIP_R_16F) +GEMM_EX_LAUNCHER(sycl::half, sycl::half, float, float, rocblas_gemm_ex, rocblas_datatype_f16_r, + rocblas_datatype_f16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, + rocblas_datatype_f16_r, rocblas_datatype_f16_r, rocblas_datatype_f16_r, + rocblas_datatype_f16_r) + +GEMM_EX_LAUNCHER(bfloat16, bfloat16, float, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER(bfloat16, bfloat16, bfloat16, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_bf16_r, rocblas_datatype_f32_r) #undef GEMM_EX_LAUNCHER @@ -127,6 +135,7 @@ inline void symm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -166,6 +175,7 @@ inline void hemm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -191,6 +201,7 @@ inline void hemm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe hemm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, \ c, ldc); \ } + HEMM_LAUNCHER(std::complex, rocblas_chemm) HEMM_LAUNCHER(std::complex, rocblas_zhemm) @@ -202,6 +213,7 @@ inline void syrk(Func func, sycl::queue &queue, uplo upper_lower, transpose tran sycl::buffer &c, int64_t ldc) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto c_acc = c.template get_access(cgh); @@ -239,6 +251,7 @@ inline void herk(Func func, sycl::queue &queue, uplo upper_lower, transpose tran using rocDataType = typename RocEquivalentType::Type; using rocScalarType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto c_acc = c.template get_access(cgh); @@ -273,6 +286,7 @@ inline void syr2k(Func func, sycl::queue &queue, uplo upper_lower, transpose tra int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -298,6 +312,7 @@ inline void syr2k(Func func, sycl::queue &queue, uplo upper_lower, transpose tra syr2k(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, \ ldc); \ } + SYR2K_LAUNCHER(float, rocblas_ssyr2k) SYR2K_LAUNCHER(double, rocblas_dsyr2k) SYR2K_LAUNCHER(std::complex, rocblas_csyr2k) @@ -313,6 +328,7 @@ inline void her2k(Func func, sycl::queue &queue, uplo upper_lower, transpose tra using rocDataType = typename RocEquivalentType::Type; using rocScalarType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldb, ldc); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -355,6 +371,7 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe sycl::buffer &b, int64_t ldb) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -364,10 +381,17 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); rocblas_status err; +#if ROCBLAS_VERSION_MAJOR >= 4 + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), + get_rocblas_fill_mode(upper_lower), + get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag), + m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb); +#else ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), get_rocblas_fill_mode(upper_lower), get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag), m, n, (rocDataType *)&alpha, a_, lda, b_, ldb); +#endif }); }); } @@ -379,6 +403,7 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe trmm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, \ lda, b, ldb); \ } + TRMM_LAUNCHER(float, rocblas_strmm) TRMM_LAUNCHER(double, rocblas_dtrmm) TRMM_LAUNCHER(std::complex, rocblas_ctrmm) @@ -392,6 +417,7 @@ inline void trsm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe sycl::buffer &b, int64_t ldb) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto b_acc = b.template get_access(cgh); @@ -416,6 +442,7 @@ inline void trsm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe trsm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, \ lda, b, ldb); \ } + TRSM_LAUNCHER(float, rocblas_strsm) TRSM_LAUNCHER(double, rocblas_dtrsm) TRSM_LAUNCHER(std::complex, rocblas_ctrsm) @@ -432,11 +459,9 @@ inline sycl::event gemm(Func func, sycl::queue &queue, transpose transa, transpo const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -449,6 +474,7 @@ inline sycl::event gemm(Func func, sycl::queue &queue, transpose transa, transpo a_, lda, b_, ldb, (rocDataType *)&beta, c_, ldc); }); }); + return done; } @@ -467,71 +493,72 @@ GEMM_LAUNCHER_USM(std::complex, rocblas_cgemm) GEMM_LAUNCHER_USM(std::complex, rocblas_zgemm) #undef GEMM_LAUNCHER_USM -template -inline sycl::event gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, - sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T_C alpha, const T_A *a, int64_t lda, const T_B *b, - int64_t ldb, T_C beta, T_C *c, int64_t ldc, - const std::vector &dependencies) { + +template +inline sycl::event gemm_ex(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, + COMPUTETYPE CT, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, T_S alpha, const T_A *a, int64_t lda, + const T_B *b, int64_t ldb, T_S beta, T_C *c, int64_t ldc, + const std::vector &dependencies) { using rocDataType_A = typename RocEquivalentType::Type; using rocDataType_B = typename RocEquivalentType::Type; using rocDataType_C = typename RocEquivalentType::Type; + using rocDataType_S = typename RocEquivalentType::Type; overflow_check(m, n, k, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); + auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); auto c_ = reinterpret_cast(c); rocblas_status err; ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_operation(transa), - get_rocblas_operation(transb), m, n, k, (rocDataType_C *)&alpha, - a_, DT_A, lda, b_, DT_B, ldb, (rocDataType_C *)&beta, c_, DT_C, - ldc, c_, DT_C, ldc, DT_C, rocblas_gemm_algo_standard, 0, 0); + get_rocblas_operation(transb), m, n, k, (rocDataType_S *)&alpha, + a_, DT_A, lda, b_, DT_B, ldb, (rocDataType_S *)&beta, c_, DT_C, + ldc, c_, DT_C, ldc, CT, rocblas_gemm_algo_standard, 0, 0); }); }); + return done; } -#define GEMM_EX_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, ROCBLAS_ROUTINE, ROCMDATATYPE_A, \ - ROCMDATATYPE_B, ROCMDATATYPE_C) \ +#define GEMM_EX_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S, ROCBLAS_ROUTINE, ROCMDATATYPE_A, \ + ROCMDATATYPE_B, ROCMDATATYPE_C, ROCMCOMPUTETYPE) \ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_C alpha, const TYPE_A *a, int64_t lda, const TYPE_B *b, \ - int64_t ldb, TYPE_C beta, TYPE_C *c, int64_t ldc, \ + int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, const TYPE_B *b, \ + int64_t ldb, TYPE_S beta, TYPE_C *c, int64_t ldc, \ const std::vector &dependencies) { \ - return gemm(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, queue, \ - transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, dependencies); \ + return gemm_ex(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, \ + ROCMCOMPUTETYPE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ + beta, c, ldc, dependencies); \ } -GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, float, rocblas_gemm_ex, rocblas_datatype_f16_r, - rocblas_datatype_f16_r, rocblas_datatype_f32_r) -GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, rocblas_datatype_f16_r, - rocblas_datatype_f16_r, rocblas_datatype_f16_r) +GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, float, float, rocblas_gemm_ex, rocblas_datatype_f16_r, + rocblas_datatype_f16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, + rocblas_datatype_f16_r, rocblas_datatype_f16_r, rocblas_datatype_f16_r, + rocblas_datatype_f16_r) + +GEMM_EX_LAUNCHER_USM(bfloat16, bfloat16, float, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER_USM(bfloat16, bfloat16, bfloat16, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_bf16_r, rocblas_datatype_f32_r) #undef GEMM_EX_LAUNCHER_USM -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, const bfloat16 *a, int64_t lda, const bfloat16 *b, - int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm", "for column_major layout"); -} template inline sycl::event symm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, int64_t ldb, T beta, T *c, int64_t ldc, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -544,6 +571,7 @@ inline sycl::event symm(Func func, sycl::queue &queue, side left_right, uplo upp a_, lda, b_, ldb, (rocDataType *)&beta, c_, ldc); }); }); + return done; } @@ -569,11 +597,9 @@ inline sycl::event hemm(Func func, sycl::queue &queue, side left_right, uplo upp T beta, T *c, int64_t ldc, const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -586,6 +612,7 @@ inline sycl::event hemm(Func func, sycl::queue &queue, side left_right, uplo upp a_, lda, b_, ldb, (rocDataType *)&beta, c_, ldc); }); }); + return done; } @@ -597,6 +624,7 @@ inline sycl::event hemm(Func func, sycl::queue &queue, side left_right, uplo upp return hemm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, \ beta, c, ldc, dependencies); \ } + HEMM_LAUNCHER_USM(std::complex, rocblas_chemm) HEMM_LAUNCHER_USM(std::complex, rocblas_zhemm) @@ -608,11 +636,9 @@ inline sycl::event syrk(Func func, sycl::queue &queue, uplo upper_lower, transpo const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -624,6 +650,7 @@ inline sycl::event syrk(Func func, sycl::queue &queue, uplo upper_lower, transpo lda, (rocDataType *)&beta, c_, ldc); }); }); + return done; } @@ -650,11 +677,9 @@ inline sycl::event herk(Func func, sycl::queue &queue, uplo upper_lower, transpo using rocDataType = typename RocEquivalentType::Type; using rocScalarType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -666,6 +691,7 @@ inline sycl::event herk(Func func, sycl::queue &queue, uplo upper_lower, transpo lda, (rocScalarType *)&beta, c_, ldc); }); }); + return done; } @@ -690,11 +716,9 @@ inline sycl::event syr2k(Func func, sycl::queue &queue, uplo upper_lower, transp const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -707,6 +731,7 @@ inline sycl::event syr2k(Func func, sycl::queue &queue, uplo upper_lower, transp lda, b_, ldb, (rocDataType *)&beta, c_, ldc); }); }); + return done; } @@ -718,6 +743,7 @@ inline sycl::event syr2k(Func func, sycl::queue &queue, uplo upper_lower, transp return syr2k(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, \ beta, c, ldc, dependencies); \ } + SYR2K_LAUNCHER_USM(float, rocblas_ssyr2k) SYR2K_LAUNCHER_USM(double, rocblas_dsyr2k) SYR2K_LAUNCHER_USM(std::complex, rocblas_csyr2k) @@ -733,11 +759,9 @@ inline sycl::event her2k(Func func, sycl::queue &queue, uplo upper_lower, transp using rocDataType = typename RocEquivalentType::Type; using rocScalarType = typename RocEquivalentType::Type; overflow_check(n, k, lda, ldb, ldc); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -750,6 +774,7 @@ inline sycl::event her2k(Func func, sycl::queue &queue, uplo upper_lower, transp lda, b_, ldb, (rocScalarType *)&beta, c_, ldc); }); }); + return done; } @@ -778,23 +803,29 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); rocblas_status err; +#if ROCBLAS_VERSION_MAJOR >= 4 + ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), + get_rocblas_fill_mode(upper_lower), + get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag), + m, n, (rocDataType *)&alpha, a_, lda, b_, ldb, b_, ldb); +#else ROCBLAS_ERROR_FUNC_SYNC(func, err, handle, get_rocblas_side_mode(left_right), get_rocblas_fill_mode(upper_lower), get_rocblas_operation(trans), get_rocblas_diag_type(unit_diag), m, n, (rocDataType *)&alpha, a_, lda, b_, ldb); +#endif }); }); + return done; } @@ -805,6 +836,7 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp return trmm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, \ alpha, a, lda, b, ldb, dependencies); \ } + TRMM_LAUNCHER_USM(float, rocblas_strmm) TRMM_LAUNCHER_USM(double, rocblas_dtrmm) TRMM_LAUNCHER_USM(std::complex, rocblas_ctrmm) @@ -819,11 +851,9 @@ inline sycl::event trsm(Func func, sycl::queue &queue, side left_right, uplo upp const std::vector &dependencies) { using rocDataType = typename RocEquivalentType::Type; overflow_check(m, n, lda, ldb); + auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_rocblas_host_task(cgh, queue, [=](RocblasScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -836,6 +866,7 @@ inline sycl::event trsm(Func func, sycl::queue &queue, side left_right, uplo upp m, n, (rocDataType *)&alpha, a_, lda, b_, ldb); }); }); + return done; } @@ -846,6 +877,7 @@ inline sycl::event trsm(Func func, sycl::queue &queue, side left_right, uplo upp return trsm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, \ alpha, a, lda, b, ldb, dependencies); \ } + TRSM_LAUNCHER_USM(float, rocblas_strsm) TRSM_LAUNCHER_USM(double, rocblas_dtrsm) TRSM_LAUNCHER_USM(std::complex, rocblas_ctrsm) @@ -854,6 +886,7 @@ TRSM_LAUNCHER_USM(std::complex, rocblas_ztrsm) #undef TRSM_LAUNCHER_USM } // namespace column_major + namespace row_major { // Buffer APIs @@ -862,7 +895,11 @@ template inline void gemm(Func func, sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "gemm", "for row_major layout"); + auto new_transa = transb; + auto new_transb = transa; + + column_major::gemm(func, queue, new_transa, new_transb, n, m, k, alpha, b, ldb, a, lda, beta, c, + ldc); } #define GEMM_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -881,40 +918,53 @@ GEMM_LAUNCHER(std::complex, rocblas_zgemm) #undef GEMM_LAUNCHER -template -inline void gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, sycl::queue &queue, - transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, T_C alpha, - sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, - T_C beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "gemm", "for row_major layout"); +template +inline void gemm_ex(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, COMPUTETYPE CT, + sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, + int64_t k, T_S alpha, sycl::buffer &a, int64_t lda, + sycl::buffer &b, int64_t ldb, T_S beta, sycl::buffer &c, + int64_t ldc) { + auto new_transa = transb; + auto new_transb = transa; + + column_major::gemm_ex(func, DT_A, DT_B, DT_C, CT, queue, new_transa, new_transb, n, m, k, alpha, + b, ldb, a, lda, beta, c, ldc); } -#define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, \ - ROCMDATATYPE_C) \ +#define GEMM_EX_LAUNCHER(TYPE_A, TYPE_B, TYPE_C, TYPE_S, ROCBLAS_ROUTINE, ROCMDATATYPE_A, \ + ROCMDATATYPE_B, ROCMDATATYPE_C, ROCMCOMPUTETYPE) \ void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_C alpha, sycl::buffer &a, int64_t lda, \ - sycl::buffer &b, int64_t ldb, TYPE_C beta, sycl::buffer &c, \ + int64_t k, TYPE_S alpha, sycl::buffer &a, int64_t lda, \ + sycl::buffer &b, int64_t ldb, TYPE_S beta, sycl::buffer &c, \ int64_t ldc) { \ - gemm(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, queue, transa, \ - transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ + gemm_ex(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, ROCMCOMPUTETYPE, \ + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc); \ } -GEMM_EX_LAUNCHER(sycl::half, sycl::half, float, rocblas_gemm_ex, HIP_R_16F, HIP_R_16F, HIP_R_32F) -GEMM_EX_LAUNCHER(sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, HIP_R_16F, HIP_R_16F, - HIP_R_16F) -#undef GEMM_EX_LAUNCHER -void gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, int64_t k, - float alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, - int64_t ldb, float beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "gemm", "for row_major layout"); -} +GEMM_EX_LAUNCHER(sycl::half, sycl::half, float, float, rocblas_gemm_ex, rocblas_datatype_f16_r, + rocblas_datatype_f16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER(sycl::half, sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, + rocblas_datatype_f16_r, rocblas_datatype_f16_r, rocblas_datatype_f16_r, + rocblas_datatype_f16_r) + +GEMM_EX_LAUNCHER(bfloat16, bfloat16, float, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER(bfloat16, bfloat16, bfloat16, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_bf16_r, rocblas_datatype_f32_r) + +#undef GEMM_EX_LAUNCHER template inline void symm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "symm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::symm(func, queue, new_side, new_uplo, n, m, alpha, a, lda, b, ldb, beta, c, ldc); } #define SYMM_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -936,7 +986,12 @@ template inline void hemm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "hemm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::hemm(func, queue, new_side, new_uplo, n, m, alpha, a, lda, b, ldb, beta, c, ldc); } #define HEMM_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -946,6 +1001,7 @@ inline void hemm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe hemm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, beta, \ c, ldc); \ } + HEMM_LAUNCHER(std::complex, rocblas_chemm) HEMM_LAUNCHER(std::complex, rocblas_zhemm) @@ -955,7 +1011,12 @@ template inline void syrk(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, T beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "syrk", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::syrk(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, beta, c, ldc); } #define SYRK_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -976,7 +1037,12 @@ template inline void herk(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, ScalarType alpha, sycl::buffer &a, int64_t lda, ScalarType beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "herk", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::conjtrans + : oneapi::mkl::transpose::nontrans; + + column_major::herk(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, beta, c, ldc); } #define HERK_LAUNCHER(DATA_TYPE, SCALAR_TYPE, ROCBLAS_ROUTINE) \ @@ -995,7 +1061,13 @@ template inline void syr2k(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, T beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "syr2k", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + column_major::syr2k(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, b, ldb, beta, c, + ldc); } #define SYR2K_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1005,6 +1077,7 @@ inline void syr2k(Func func, sycl::queue &queue, uplo upper_lower, transpose tra syr2k(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, beta, c, \ ldc); \ } + SYR2K_LAUNCHER(float, rocblas_ssyr2k) SYR2K_LAUNCHER(double, rocblas_dsyr2k) SYR2K_LAUNCHER(std::complex, rocblas_csyr2k) @@ -1017,7 +1090,14 @@ inline void her2k(Func func, sycl::queue &queue, uplo upper_lower, transpose tra int64_t k, DataType alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb, ScalarType beta, sycl::buffer &c, int64_t ldc) { - throw unimplemented("blas", "her2k", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::conjtrans + : oneapi::mkl::transpose::nontrans; + auto new_alpha = std::conj(alpha); + + column_major::her2k(func, queue, new_uplo, new_trans, n, k, new_alpha, a, lda, b, ldb, beta, c, + ldc); } #define HER2K_LAUNCHER(DATA_TYPE, SCALAR_TYPE, ROCBLAS_ROUTINE) \ @@ -1042,7 +1122,13 @@ template inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "trmm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::trmm(func, queue, new_side, new_uplo, trans, unit_diag, n, m, alpha, a, lda, b, + ldb); } #define TRMM_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1052,6 +1138,7 @@ inline void trmm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe trmm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, \ lda, b, ldb); \ } + TRMM_LAUNCHER(float, rocblas_strmm) TRMM_LAUNCHER(double, rocblas_dtrmm) TRMM_LAUNCHER(std::complex, rocblas_ctrmm) @@ -1063,7 +1150,13 @@ template inline void trsm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, sycl::buffer &a, int64_t lda, sycl::buffer &b, int64_t ldb) { - throw unimplemented("blas", "trsm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + column_major::trsm(func, queue, new_side, new_uplo, trans, unit_diag, n, m, alpha, a, lda, b, + ldb); } #define TRSM_LAUNCHER(TYPE, ROCBLAS_ROUTINE) \ @@ -1073,6 +1166,7 @@ inline void trsm(Func func, sycl::queue &queue, side left_right, uplo upper_lowe trsm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, alpha, a, \ lda, b, ldb); \ } + TRSM_LAUNCHER(float, rocblas_strsm) TRSM_LAUNCHER(double, rocblas_dtrsm) TRSM_LAUNCHER(std::complex, rocblas_ctrsm) @@ -1087,7 +1181,11 @@ inline sycl::event gemm(Func func, sycl::queue &queue, transpose transa, transpo int64_t m, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, const T *b, int64_t ldb, T beta, T *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "gemm", "for row_major layout"); + auto new_transa = transb; + auto new_transb = transa; + + return column_major::gemm(func, queue, new_transa, new_transb, n, m, k, alpha, b, ldb, a, lda, + beta, c, ldc, dependencies); } #define GEMM_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1105,44 +1203,56 @@ GEMM_LAUNCHER_USM(std::complex, rocblas_cgemm) GEMM_LAUNCHER_USM(std::complex, rocblas_zgemm) #undef GEMM_LAUNCHER_USM -template -inline sycl::event gemm(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, - sycl::queue &queue, transpose transa, transpose transb, int64_t m, - int64_t n, int64_t k, T_C alpha, const T_A *a, int64_t lda, const T_B *b, - int64_t ldb, T_C beta, T_C *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm", "for row_major layout"); + +template +inline sycl::event gemm_ex(Func func, DATATYPE_A DT_A, DATATYPE_B DT_B, DATATYPE_C DT_C, + COMPUTETYPE CT, sycl::queue &queue, transpose transa, transpose transb, + int64_t m, int64_t n, int64_t k, T_S alpha, const T_A *a, int64_t lda, + const T_B *b, int64_t ldb, T_S beta, T_C *c, int64_t ldc, + const std::vector &dependencies) { + auto new_transa = transb; + auto new_transb = transa; + + return column_major::gemm_ex(func, DT_A, DT_B, DT_C, CT, queue, new_transa, new_transb, n, m, k, + alpha, b, ldb, a, lda, beta, c, ldc, dependencies); } -#define GEMM_EX_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, ROCBLAS_ROUTINE, ROCMDATATYPE_A, \ - ROCMDATATYPE_B, ROCMDATATYPE_C) \ +#define GEMM_EX_LAUNCHER_USM(TYPE_A, TYPE_B, TYPE_C, TYPE_S, ROCBLAS_ROUTINE, ROCMDATATYPE_A, \ + ROCMDATATYPE_B, ROCMDATATYPE_C, ROCMCOMPUTETYPE) \ sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, \ - int64_t k, TYPE_C alpha, const TYPE_A *a, int64_t lda, const TYPE_B *b, \ - int64_t ldb, TYPE_C beta, TYPE_C *c, int64_t ldc, \ + int64_t k, TYPE_S alpha, const TYPE_A *a, int64_t lda, const TYPE_B *b, \ + int64_t ldb, TYPE_S beta, TYPE_C *c, int64_t ldc, \ const std::vector &dependencies) { \ - return gemm(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, queue, \ - transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, dependencies); \ + return gemm_ex(ROCBLAS_ROUTINE, ROCMDATATYPE_A, ROCMDATATYPE_B, ROCMDATATYPE_C, \ + ROCMCOMPUTETYPE, queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, \ + beta, c, ldc, dependencies); \ } -GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, float, rocblas_gemm_ex, HIP_R_16F, HIP_R_16F, - HIP_R_32F) -GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, HIP_R_16F, HIP_R_16F, - HIP_R_16F) +GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, float, float, rocblas_gemm_ex, rocblas_datatype_f16_r, + rocblas_datatype_f16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER_USM(sycl::half, sycl::half, sycl::half, sycl::half, rocblas_gemm_ex, + rocblas_datatype_f16_r, rocblas_datatype_f16_r, rocblas_datatype_f16_r, + rocblas_datatype_f16_r) + +GEMM_EX_LAUNCHER_USM(bfloat16, bfloat16, float, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_f32_r, rocblas_datatype_f32_r) +GEMM_EX_LAUNCHER_USM(bfloat16, bfloat16, bfloat16, float, rocblas_gemm_ex, rocblas_datatype_bf16_r, + rocblas_datatype_bf16_r, rocblas_datatype_bf16_r, rocblas_datatype_f32_r) #undef GEMM_EX_LAUNCHER_USM -sycl::event gemm(sycl::queue &queue, transpose transa, transpose transb, int64_t m, int64_t n, - int64_t k, float alpha, const bfloat16 *a, int64_t lda, const bfloat16 *b, - int64_t ldb, float beta, float *c, int64_t ldc, - const std::vector &dependencies) { - throw unimplemented("blas", "gemm", "for row_major layout"); -} template inline sycl::event symm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, int64_t ldb, T beta, T *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "symm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::symm(func, queue, new_side, new_uplo, n, m, alpha, a, lda, b, ldb, beta, c, + ldc, dependencies); } #define SYMM_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1165,7 +1275,13 @@ template inline sycl::event hemm(Func func, sycl::queue &queue, side left_right, uplo upper_lower, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, const T *b, int64_t ldb, T beta, T *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "hemm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::hemm(func, queue, new_side, new_uplo, n, m, alpha, a, lda, b, ldb, beta, c, + ldc, dependencies); } #define HEMM_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1176,6 +1292,7 @@ inline sycl::event hemm(Func func, sycl::queue &queue, side left_right, uplo upp return hemm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, m, n, alpha, a, lda, b, ldb, \ beta, c, ldc, dependencies); \ } + HEMM_LAUNCHER_USM(std::complex, rocblas_chemm) HEMM_LAUNCHER_USM(std::complex, rocblas_zhemm) @@ -1185,7 +1302,13 @@ template inline sycl::event syrk(Func func, sycl::queue &queue, uplo upper_lower, transpose trans, int64_t n, int64_t k, T alpha, const T *a, int64_t lda, T beta, T *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "syrk", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::syrk(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, beta, c, ldc, + dependencies); } #define SYRK_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1208,7 +1331,13 @@ inline sycl::event herk(Func func, sycl::queue &queue, uplo upper_lower, transpo int64_t k, const ScalarType alpha, const DataType *a, int64_t lda, const ScalarType beta, DataType *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "herk", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::conjtrans + : oneapi::mkl::transpose::nontrans; + + return column_major::herk(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, beta, c, ldc, + dependencies); } #define HERK_LAUNCHER_USM(DATA_TYPE, SCALAR_TYPE, ROCBLAS_ROUTINE) \ @@ -1230,7 +1359,13 @@ inline sycl::event syr2k(Func func, sycl::queue &queue, uplo upper_lower, transp int64_t n, int64_t k, T alpha, const T *a, int64_t lda, const T *b, int64_t ldb, T beta, T *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "syr2k", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::trans + : oneapi::mkl::transpose::nontrans; + + return column_major::syr2k(func, queue, new_uplo, new_trans, n, k, alpha, a, lda, b, ldb, beta, + c, ldc, dependencies); } #define SYR2K_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1241,6 +1376,7 @@ inline sycl::event syr2k(Func func, sycl::queue &queue, uplo upper_lower, transp return syr2k(ROCBLAS_ROUTINE, queue, upper_lower, trans, n, k, alpha, a, lda, b, ldb, \ beta, c, ldc, dependencies); \ } + SYR2K_LAUNCHER_USM(float, rocblas_ssyr2k) SYR2K_LAUNCHER_USM(double, rocblas_dsyr2k) SYR2K_LAUNCHER_USM(std::complex, rocblas_csyr2k) @@ -1253,7 +1389,14 @@ inline sycl::event her2k(Func func, sycl::queue &queue, uplo upper_lower, transp int64_t n, int64_t k, const DataType alpha, const DataType *a, int64_t lda, const DataType *b, int64_t ldb, const ScalarType beta, DataType *c, int64_t ldc, const std::vector &dependencies) { - throw unimplemented("blas", "her2k", "for row_major layout"); + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + auto new_trans = trans == oneapi::mkl::transpose::nontrans ? oneapi::mkl::transpose::conjtrans + : oneapi::mkl::transpose::nontrans; + auto new_alpha = std::conj(alpha); + + return column_major::her2k(func, queue, new_uplo, new_trans, n, k, new_alpha, a, lda, b, ldb, + beta, c, ldc, dependencies); } #define HER2K_LAUNCHER_USM(DATA_TYPE, SCALAR_TYPE, ROCBLAS_ROUTINE) \ @@ -1279,7 +1422,13 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "trmm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::trmm(func, queue, new_side, new_uplo, trans, unit_diag, n, m, alpha, a, + lda, b, ldb, dependencies); } #define TRMM_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1289,6 +1438,7 @@ inline sycl::event trmm(Func func, sycl::queue &queue, side left_right, uplo upp return trmm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, \ alpha, a, lda, b, ldb, dependencies); \ } + TRMM_LAUNCHER_USM(float, rocblas_strmm) TRMM_LAUNCHER_USM(double, rocblas_dtrmm) TRMM_LAUNCHER_USM(std::complex, rocblas_ctrmm) @@ -1301,7 +1451,13 @@ inline sycl::event trsm(Func func, sycl::queue &queue, side left_right, uplo upp transpose trans, diag unit_diag, int64_t m, int64_t n, T alpha, const T *a, int64_t lda, T *b, int64_t ldb, const std::vector &dependencies) { - throw unimplemented("blas", "trsm", "for row_major layout"); + auto new_side = + left_right == oneapi::mkl::side::left ? oneapi::mkl::side::right : oneapi::mkl::side::left; + auto new_uplo = upper_lower == oneapi::mkl::uplo::lower ? oneapi::mkl::uplo::upper + : oneapi::mkl::uplo::lower; + + return column_major::trsm(func, queue, new_side, new_uplo, trans, unit_diag, n, m, alpha, a, + lda, b, ldb, dependencies); } #define TRSM_LAUNCHER_USM(TYPE, ROCBLAS_ROUTINE) \ @@ -1311,6 +1467,7 @@ inline sycl::event trsm(Func func, sycl::queue &queue, side left_right, uplo upp return trsm(ROCBLAS_ROUTINE, queue, left_right, upper_lower, trans, unit_diag, m, n, \ alpha, a, lda, b, ldb, dependencies); \ } + TRSM_LAUNCHER_USM(float, rocblas_strsm) TRSM_LAUNCHER_USM(double, rocblas_dtrsm) TRSM_LAUNCHER_USM(std::complex, rocblas_ctrsm) diff --git a/src/blas/backends/rocblas/rocblas_scope_handle.cpp b/src/blas/backends/rocblas/rocblas_scope_handle.cpp index 2abf0323b..404d1fc06 100644 --- a/src/blas/backends/rocblas/rocblas_scope_handle.cpp +++ b/src/blas/backends/rocblas/rocblas_scope_handle.cpp @@ -58,10 +58,11 @@ RocblasScopedContextHandler::RocblasScopedContextHandler(sycl::queue queue, : interop_h(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto device = queue.get_device(); - auto desired = sycl::get_native(*placedContext_); + auto hipDevice = ih.get_native_device(); hipError_t err; + hipCtx_t desired; HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_); + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired); @@ -103,8 +104,11 @@ void ContextCallback(void *userData) { } rocblas_handle RocblasScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto hipDevice = interop_h.get_native_device(); + hipError_t hipErr; + hipCtx_t desired; + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice); + auto piPlacedContext_ = reinterpret_cast(desired); hipStream_t streamId = get_stream(queue); rocblas_status err; auto it = handle_helper.rocblas_handle_container_mapper_.find(piPlacedContext_); diff --git a/src/blas/backends/rocblas/rocblas_task.hpp b/src/blas/backends/rocblas/rocblas_task.hpp index 13686e9e4..94e2b2b4a 100644 --- a/src/blas/backends/rocblas/rocblas_task.hpp +++ b/src/blas/backends/rocblas/rocblas_task.hpp @@ -20,7 +20,7 @@ **************************************************************************/ #ifndef _ROCBLAS_TASK_HPP_ #define _ROCBLAS_TASK_HPP_ -#include +#include #include #if __has_include() #include diff --git a/src/blas/backends/rocblas/rocblas_wrappers.cpp b/src/blas/backends/rocblas/rocblas_wrappers.cpp index 181c8d9d1..ce4c92da5 100644 --- a/src/blas/backends/rocblas/rocblas_wrappers.cpp +++ b/src/blas/backends/rocblas/rocblas_wrappers.cpp @@ -207,6 +207,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, oneapi::mkl::blas::rocblas::column_major::trsm_batch, @@ -235,6 +238,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::omatcopy, oneapi::mkl::blas::rocblas::column_major::omatcopy, oneapi::mkl::blas::rocblas::column_major::omatcopy, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, oneapi::mkl::blas::rocblas::column_major::imatcopy, oneapi::mkl::blas::rocblas::column_major::imatcopy, oneapi::mkl::blas::rocblas::column_major::imatcopy, @@ -458,6 +465,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, + oneapi::mkl::blas::rocblas::column_major::gemm_batch, oneapi::mkl::blas::rocblas::column_major::gemmt, oneapi::mkl::blas::rocblas::column_major::gemmt, oneapi::mkl::blas::rocblas::column_major::gemmt, @@ -482,6 +495,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::column_major::omatcopy, oneapi::mkl::blas::rocblas::column_major::omatcopy, oneapi::mkl::blas::rocblas::column_major::omatcopy, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, + oneapi::mkl::blas::rocblas::column_major::omatcopy2, oneapi::mkl::blas::rocblas::column_major::imatcopy, oneapi::mkl::blas::rocblas::column_major::imatcopy, oneapi::mkl::blas::rocblas::column_major::imatcopy, @@ -680,6 +697,9 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, oneapi::mkl::blas::rocblas::row_major::trsm_batch, @@ -708,6 +728,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::omatcopy, oneapi::mkl::blas::rocblas::row_major::omatcopy, oneapi::mkl::blas::rocblas::row_major::omatcopy, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, oneapi::mkl::blas::rocblas::row_major::imatcopy, oneapi::mkl::blas::rocblas::row_major::imatcopy, oneapi::mkl::blas::rocblas::row_major::imatcopy, @@ -931,6 +955,12 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, + oneapi::mkl::blas::rocblas::row_major::gemm_batch, oneapi::mkl::blas::rocblas::row_major::gemmt, oneapi::mkl::blas::rocblas::row_major::gemmt, oneapi::mkl::blas::rocblas::row_major::gemmt, @@ -955,6 +985,10 @@ extern "C" blas_function_table_t mkl_blas_table = { oneapi::mkl::blas::rocblas::row_major::omatcopy, oneapi::mkl::blas::rocblas::row_major::omatcopy, oneapi::mkl::blas::rocblas::row_major::omatcopy, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, + oneapi::mkl::blas::rocblas::row_major::omatcopy2, oneapi::mkl::blas::rocblas::row_major::imatcopy, oneapi::mkl::blas::rocblas::row_major::imatcopy, oneapi::mkl::blas::rocblas::row_major::imatcopy, diff --git a/src/blas/blas_loader.cpp b/src/blas/blas_loader.cpp index a27db18cd..c1f1339c6 100644 --- a/src/blas/blas_loader.cpp +++ b/src/blas/blas_loader.cpp @@ -1342,6 +1342,39 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa stride_c, batch_size); } +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_gemm_f16f16f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].column_major_gemm_s8s8f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].column_major_gemm_s8s8s32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + void trsm_batch(oneapi::mkl::device libkey, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -1585,6 +1618,38 @@ void omatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, s function_tables[libkey].column_major_zomatcopy_sycl(queue, trans, m, n, alpha, a, lda, b, ldb); } +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + function_tables[libkey].column_major_somatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb); +} + +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + function_tables[libkey].column_major_domatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb); +} + +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + function_tables[libkey].column_major_comatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb); +} + +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + function_tables[libkey].column_major_zomatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, + b, ldb, strideb); +} + void imatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -3373,6 +3438,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose group_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_f16f16f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8s32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -3431,6 +3529,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose stride_c, batch_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_f16f16f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].column_major_gemm_s8s8s32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + sycl::event gemmt(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, @@ -3662,6 +3793,40 @@ sycl::event omatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose t lda, b, ldb, dependencies); } +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + return function_tables[libkey].column_major_somatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stridea, double *b, std::int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + return function_tables[libkey].column_major_domatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stridea, + std::complex *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + return function_tables[libkey].column_major_comatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stridea, + std::complex *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + return function_tables[libkey].column_major_zomatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + sycl::event imatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { @@ -5111,6 +5276,39 @@ void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa stride_c, batch_size); } +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_gemm_f16f16f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size) { + function_tables[libkey].row_major_gemm_s8s8f32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + +void gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, + float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size) { + function_tables[libkey].row_major_gemm_s8s8s32_batch_strided_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size); +} + void trsm_batch(oneapi::mkl::device libkey, sycl::queue &queue, side left_right, uplo upper_lower, transpose trans, diag unit_diag, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, @@ -5354,6 +5552,38 @@ void omatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, s function_tables[libkey].row_major_zomatcopy_sycl(queue, trans, m, n, alpha, a, lda, b, ldb); } +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, float alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + function_tables[libkey].row_major_somatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, b, + ldb, strideb); +} + +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, double alpha, sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, std::int64_t ldb, + std::int64_t strideb) { + function_tables[libkey].row_major_domatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, b, + ldb, strideb); +} + +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + function_tables[libkey].row_major_comatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, b, + ldb, strideb); +} + +void omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, + std::int64_t lda, std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb) { + function_tables[libkey].row_major_zomatcopy2_sycl(queue, trans, m, n, alpha, a, lda, stridea, b, + ldb, strideb); +} + void imatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, std::int64_t ldb) { @@ -7138,6 +7368,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose group_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const sycl::half **a, std::int64_t *lda, const sycl::half **b, + std::int64_t *ldb, float *beta, float **c, std::int64_t *ldc, + std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_f16f16f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8f32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose *transa, + transpose *transb, std::int64_t *m, std::int64_t *n, std::int64_t *k, + float *alpha, const std::int8_t **a, std::int64_t *lda, + const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8s32_batch_group_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, group_count, + group_size, dependencies); +} + sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, std::int64_t stride_a, @@ -7196,6 +7459,39 @@ sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose stride_c, batch_size, dependencies); } +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const sycl::half *a, std::int64_t lda, std::int64_t stride_a, + const sycl::half *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_f16f16f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + float *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, + const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8f32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + +sycl::event gemm_batch(oneapi::mkl::device libkey, sycl::queue &queue, transpose transa, + transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, + float alpha, const std::int8_t *a, std::int64_t lda, std::int64_t stride_a, + const std::int8_t *b, std::int64_t ldb, std::int64_t stride_b, float beta, + std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies) { + return function_tables[libkey].row_major_gemm_s8s8s32_batch_strided_usm_sycl( + queue, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb, stride_b, beta, c, ldc, + stride_c, batch_size, dependencies); +} + sycl::event gemmt(oneapi::mkl::device libkey, sycl::queue &queue, uplo upper_lower, transpose transa, transpose transb, std::int64_t n, std::int64_t k, float alpha, const float *a, std::int64_t lda, const float *b, std::int64_t ldb, float beta, @@ -7427,6 +7723,40 @@ sycl::event omatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose t b, ldb, dependencies); } +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, float alpha, const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + return function_tables[libkey].row_major_somatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, double alpha, const double *a, + std::int64_t lda, std::int64_t stridea, double *b, std::int64_t ldb, + std::int64_t strideb, const std::vector &dependencies) { + return function_tables[libkey].row_major_domatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stridea, + std::complex *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + return function_tables[libkey].row_major_comatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + +sycl::event omatcopy2(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, std::int64_t stridea, + std::complex *b, std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies) { + return function_tables[libkey].row_major_zomatcopy2_usm_sycl( + queue, trans, m, n, alpha, a, lda, stridea, b, ldb, strideb, dependencies); +} + sycl::event imatcopy(oneapi::mkl::device libkey, sycl::queue &queue, transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, const std::vector &dependencies) { diff --git a/src/blas/function_table.hpp b/src/blas/function_table.hpp index c8289fcf0..a242fd0c0 100644 --- a/src/blas/function_table.hpp +++ b/src/blas/function_table.hpp @@ -869,6 +869,26 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + void (*column_major_gemm_f16f16f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*column_major_gemm_s8s8f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*column_major_gemm_s8s8s32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); void (*column_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -1030,6 +1050,28 @@ typedef struct { std::int64_t m, std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb); + void (*column_major_somatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + void (*column_major_domatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + void (*column_major_comatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t strideb); + void (*column_major_zomatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t strideb); void (*column_major_simatcopy_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, @@ -2158,6 +2200,24 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_f16f16f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8s32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); sycl::event (*column_major_sgemm_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -2191,6 +2251,24 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_f16f16f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*column_major_gemm_s8s8s32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); sycl::event (*column_major_sgemmt_usm_sycl)(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, @@ -2326,6 +2404,28 @@ typedef struct { const std::complex *a, std::int64_t lda, std::complex *b, std::int64_t ldb, const std::vector &dependencies); + sycl::event (*column_major_somatcopy2_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + float alpha, const float *a, std::int64_t lda, std::int64_t stridea, float *b, + std::int64_t ldb, std::int64_t strideb, const std::vector &dependencies); + sycl::event (*column_major_domatcopy2_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, + double alpha, const double *a, std::int64_t lda, std::int64_t stridea, double *b, + std::int64_t ldb, std::int64_t strideb, const std::vector &dependencies); + sycl::event (*column_major_comatcopy2_usm_sycl)(sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies); + sycl::event (*column_major_zomatcopy2_usm_sycl)(sycl::queue &queue, + oneapi::mkl::transpose trans, std::int64_t m, + std::int64_t n, std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies); sycl::event (*column_major_simatcopy_usm_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, @@ -3225,6 +3325,26 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_gemm_f16f16f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, sycl::buffer &a, + std::int64_t lda, std::int64_t stride_a, sycl::buffer &b, std::int64_t ldb, + std::int64_t stride_b, float beta, sycl::buffer &c, std::int64_t ldc, + std::int64_t stride_c, std::int64_t batch_size); + void (*row_major_gemm_s8s8f32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); + void (*row_major_gemm_s8s8s32_batch_strided_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &b, std::int64_t ldb, std::int64_t stride_b, float beta, + sycl::buffer &c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size); void (*row_major_strsm_batch_strided_sycl)( sycl::queue &queue, oneapi::mkl::side left_right, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose trans, oneapi::mkl::diag unit_diag, std::int64_t m, std::int64_t n, @@ -3386,6 +3506,27 @@ typedef struct { std::int64_t m, std::int64_t n, std::complex alpha, sycl::buffer, 1> &a, std::int64_t lda, sycl::buffer, 1> &b, std::int64_t ldb); + void (*row_major_somatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + void (*row_major_domatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, + sycl::buffer &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer &b, + std::int64_t ldb, std::int64_t strideb); + void (*row_major_comatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, sycl::buffer, 1> &b, + std::int64_t ldb, std::int64_t strideb); + void (*row_major_zomatcopy2_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, std::complex alpha, + sycl::buffer, 1> &a, std::int64_t lda, + std::int64_t stridea, + sycl::buffer, 1> &b, std::int64_t ldb, + std::int64_t strideb); void (*row_major_simatcopy_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, sycl::buffer &ab, std::int64_t lda, @@ -4516,6 +4657,24 @@ typedef struct { std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, sycl::half *beta, sycl::half **c, std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_f16f16f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const sycl::half **a, + std::int64_t *lda, const sycl::half **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8f32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, float **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8s32_batch_group_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose *transa, oneapi::mkl::transpose *transb, + std::int64_t *m, std::int64_t *n, std::int64_t *k, float *alpha, const std::int8_t **a, + std::int64_t *lda, const std::int8_t **b, std::int64_t *ldb, float *beta, std::int32_t **c, + std::int64_t *ldc, std::int64_t group_count, std::int64_t *group_size, + const std::vector &dependencies); sycl::event (*row_major_sgemm_batch_strided_usm_sycl)( sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const float *a, @@ -4549,6 +4708,24 @@ typedef struct { std::int64_t stride_b, sycl::half beta, sycl::half *c, std::int64_t ldc, std::int64_t stride_c, std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_f16f16f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const sycl::half *a, + std::int64_t lda, std::int64_t stride_a, const sycl::half *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8f32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, float *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); + sycl::event (*row_major_gemm_s8s8s32_batch_strided_usm_sycl)( + sycl::queue &queue, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, + std::int64_t m, std::int64_t n, std::int64_t k, float alpha, const std::int8_t *a, + std::int64_t lda, std::int64_t stride_a, const std::int8_t *b, std::int64_t ldb, + std::int64_t stride_b, float beta, std::int32_t *c, std::int64_t ldc, std::int64_t stride_c, + std::int64_t batch_size, const std::vector &dependencies); sycl::event (*row_major_sgemmt_usm_sycl)(sycl::queue &queue, oneapi::mkl::uplo upper_lower, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, std::int64_t n, @@ -4683,6 +4860,32 @@ typedef struct { const std::complex *a, std::int64_t lda, std::complex *b, std::int64_t ldb, const std::vector &dependencies); + sycl::event (*row_major_somatcopy2_usm_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, float alpha, + const float *a, std::int64_t lda, + std::int64_t stridea, float *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies); + sycl::event (*row_major_domatcopy2_usm_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, double alpha, + const double *a, std::int64_t lda, + std::int64_t stridea, double *b, std::int64_t ldb, + std::int64_t strideb, + const std::vector &dependencies); + sycl::event (*row_major_comatcopy2_usm_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, + std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies); + sycl::event (*row_major_zomatcopy2_usm_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, + std::int64_t m, std::int64_t n, + std::complex alpha, + const std::complex *a, std::int64_t lda, + std::int64_t stridea, std::complex *b, + std::int64_t ldb, std::int64_t strideb, + const std::vector &dependencies); sycl::event (*row_major_simatcopy_usm_sycl)(sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t m, std::int64_t n, float alpha, float *ab, std::int64_t lda, std::int64_t ldb, diff --git a/src/config.hpp.in b/src/config.hpp.in index 4540d0474..5698abf9b 100644 --- a/src/config.hpp.in +++ b/src/config.hpp.in @@ -21,14 +21,22 @@ #define ONEMKL_CONFIG_H #cmakedefine ENABLE_CUBLAS_BACKEND -#cmakedefine ENABLE_CUSOLVER_BACKEND -#cmakedefine ENABLE_ROCBLAS_BACKEND -#cmakedefine ENABLE_ROCRAND_BACKEND -#cmakedefine ENABLE_ROCSOLVER_BACKEND +#cmakedefine ENABLE_CUFFT_BACKEND #cmakedefine ENABLE_CURAND_BACKEND +#cmakedefine ENABLE_CUSOLVER_BACKEND #cmakedefine ENABLE_MKLCPU_BACKEND #cmakedefine ENABLE_MKLGPU_BACKEND #cmakedefine ENABLE_NETLIB_BACKEND +#cmakedefine ENABLE_PORTBLAS_BACKEND +#cmakedefine ENABLE_PORTBLAS_BACKEND_AMD_GPU +#cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_CPU +#cmakedefine ENABLE_PORTBLAS_BACKEND_INTEL_GPU +#cmakedefine ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU +#cmakedefine ENABLE_PORTFFT_BACKEND +#cmakedefine ENABLE_ROCBLAS_BACKEND +#cmakedefine ENABLE_ROCFFT_BACKEND +#cmakedefine ENABLE_ROCRAND_BACKEND +#cmakedefine ENABLE_ROCSOLVER_BACKEND #cmakedefine BUILD_SHARED_LIBS #cmakedefine REF_BLAS_LIBNAME "@REF_BLAS_LIBNAME@" #cmakedefine REF_CBLAS_LIBNAME "@REF_CBLAS_LIBNAME@" diff --git a/src/dft/CMakeLists.txt b/src/dft/CMakeLists.txt index a12cac5e2..e3b373645 100644 --- a/src/dft/CMakeLists.txt +++ b/src/dft/CMakeLists.txt @@ -29,6 +29,7 @@ target_include_directories(onemkl_dft ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} $ ) @@ -43,4 +44,7 @@ else() target_link_libraries(onemkl_dft PUBLIC ONEMKL::SYCL::SYCL) endif() +include(WarningsUtils) +target_link_libraries(onemkl_dft PRIVATE onemkl_warnings) + endif() diff --git a/src/dft/backends/CMakeLists.txt b/src/dft/backends/CMakeLists.txt index f0cc973ba..b03a63e8a 100644 --- a/src/dft/backends/CMakeLists.txt +++ b/src/dft/backends/CMakeLists.txt @@ -17,6 +17,9 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== +add_custom_target(onemkl_backend_libs_dft) +add_dependencies(onemkl_backend_libs onemkl_backend_libs_dft) + if(ENABLE_MKLGPU_BACKEND) add_subdirectory(mklgpu) endif() @@ -24,3 +27,15 @@ endif() if(ENABLE_MKLCPU_BACKEND) add_subdirectory(mklcpu) endif() + +if(ENABLE_CUFFT_BACKEND) + add_subdirectory(cufft) +endif() + +if(ENABLE_ROCFFT_BACKEND) + add_subdirectory(rocfft) +endif() + +if(ENABLE_PORTFFT_BACKEND) + add_subdirectory(portfft) +endif() diff --git a/src/dft/backends/backend_backward_instantiations.cxx b/src/dft/backends/backend_backward_instantiations.cxx index 37c1d6885..a6aeaf71b 100644 --- a/src/dft/backends/backend_backward_instantiations.cxx +++ b/src/dft/backends/backend_backward_instantiations.cxx @@ -27,49 +27,31 @@ using desc_cd_t = dft::detail::descriptor; using depends_vec_t = const std::vector &; -#define ONEMKL_DFT_BACKWARD_INSTANTIATIONS(DESCRIPTOR_T, REAL_T, FORWARD_T, BACKWARD_T) \ - /* Buffer API */ \ - template ONEMKL_EXPORT void compute_backward(DESCRIPTOR_T & desc, \ - sycl::buffer &); \ - template ONEMKL_EXPORT void compute_backward( \ - DESCRIPTOR_T & desc, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_backward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_backward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_backward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_backward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &, \ - sycl::buffer &, sycl::buffer &); \ - \ - /* USM API */ \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, REAL_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, BACKWARD_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, BACKWARD_T *, FORWARD_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, REAL_T *, REAL_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, REAL_T *, REAL_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, REAL_T *, REAL_T *, REAL_T *, REAL_T *, depends_vec_t); - -#define ONEMKL_DFT_BACKWARD_INSTANTIATIONS_REAL_ONLY(DESCRIPTOR_T, COMPLEX_T) \ - /* Buffer API */ \ - template ONEMKL_EXPORT void compute_backward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - /* USM API */ \ - template ONEMKL_EXPORT sycl::event compute_backward( \ - DESCRIPTOR_T & desc, COMPLEX_T *, COMPLEX_T *, depends_vec_t); +#define ONEMKL_DFT_BACKWARD_INSTANTIATIONS(DESCRIPTOR_T, SCALAR_T, FORWARD_T, BACKWARD_T) \ + /* Buffer API */ \ + template ONEMKL_EXPORT void compute_backward(DESCRIPTOR_T &, \ + sycl::buffer &); \ + template ONEMKL_EXPORT void compute_backward( \ + DESCRIPTOR_T &, sycl::buffer &, sycl::buffer &); \ + template ONEMKL_EXPORT void compute_backward( \ + DESCRIPTOR_T &, sycl::buffer &, sycl::buffer &); \ + template ONEMKL_EXPORT void compute_backward( \ + DESCRIPTOR_T &, sycl::buffer &, sycl::buffer &, \ + sycl::buffer &, sycl::buffer &); \ + \ + /* USM API */ \ + template ONEMKL_EXPORT sycl::event compute_backward(DESCRIPTOR_T &, FORWARD_T *, \ + depends_vec_t); \ + template ONEMKL_EXPORT sycl::event compute_backward(DESCRIPTOR_T &, SCALAR_T *, \ + SCALAR_T *, depends_vec_t); \ + template ONEMKL_EXPORT sycl::event compute_backward( \ + DESCRIPTOR_T &, BACKWARD_T *, FORWARD_T *, depends_vec_t); \ + template ONEMKL_EXPORT sycl::event compute_backward( \ + DESCRIPTOR_T &, SCALAR_T *, SCALAR_T *, SCALAR_T *, SCALAR_T *, depends_vec_t); ONEMKL_DFT_BACKWARD_INSTANTIATIONS(desc_rf_t, float, float, std::complex) -ONEMKL_DFT_BACKWARD_INSTANTIATIONS_REAL_ONLY(desc_rf_t, std::complex) ONEMKL_DFT_BACKWARD_INSTANTIATIONS(desc_cf_t, float, std::complex, std::complex) ONEMKL_DFT_BACKWARD_INSTANTIATIONS(desc_rd_t, double, double, std::complex) -ONEMKL_DFT_BACKWARD_INSTANTIATIONS_REAL_ONLY(desc_rd_t, std::complex) ONEMKL_DFT_BACKWARD_INSTANTIATIONS(desc_cd_t, double, std::complex, std::complex) #undef ONEMKL_DFT_BACKWARD_INSTANTIATIONS diff --git a/src/dft/backends/backend_compute_signature.cxx b/src/dft/backends/backend_compute_signature.cxx index ebfbe2145..d011cb995 100644 --- a/src/dft/backends/backend_compute_signature.cxx +++ b/src/dft/backends/backend_compute_signature.cxx @@ -18,334 +18,120 @@ *******************************************************************************/ /* -When only a specific backend library is required (eg. libonemkl_dft_) -it may be preferenable to only link to that specific backend library without -the requirement that the main OneMKL library also be linked. - -To enable this, function signatures from the main dft library are duplicated -here, forwarding directly to the backend implementation instead of the function -table lookup mechanism. +repetitive definitions from commit.cpp. This file should be included for each backend, with defined to match the namespace of the backend's implementation. */ -#include "oneapi/mkl/dft/forward.hpp" -#include "oneapi/mkl/dft/backward.hpp" +using fwd_type = typename dft::detail::commit_impl::fwd_type; +using bwd_type = typename dft::detail::commit_impl::bwd_type; +using descriptor_type = typename dft::detail::descriptor; + +// forward inplace COMPLEX_COMPLEX +void forward_ip_cc(descriptor_type& desc, sycl::buffer& inout) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + oneapi::mkl::dft::BACKEND::compute_forward(desc, inout); +} +sycl::event forward_ip_cc(descriptor_type& desc, fwd_type* inout, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + return oneapi::mkl::dft::BACKEND::compute_forward(desc, inout, dependencies); +} + +// forward inplace REAL_REAL +void forward_ip_rr(descriptor_type& desc, sycl::buffer& inout_re, + sycl::buffer& inout_im) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + oneapi::mkl::dft::BACKEND::compute_forward(desc, inout_re, inout_im); +} +sycl::event forward_ip_rr(descriptor_type& desc, scalar_type* inout_re, scalar_type* inout_im, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + return oneapi::mkl::dft::BACKEND::compute_forward(desc, inout_re, inout_im, dependencies); +} + +// forward out-of-place COMPLEX_COMPLEX +void forward_op_cc(descriptor_type& desc, sycl::buffer& in, + sycl::buffer& out) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + oneapi::mkl::dft::BACKEND::compute_forward(desc, in, out); +} +sycl::event forward_op_cc(descriptor_type& desc, fwd_type* in, bwd_type* out, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + return oneapi::mkl::dft::BACKEND::compute_forward(desc, in, out, dependencies); +} -namespace oneapi { -namespace mkl { -namespace dft { +// forward out-of-place REAL_REAL +void forward_op_rr(descriptor_type& desc, sycl::buffer& in_re, + sycl::buffer& in_im, sycl::buffer& out_re, + sycl::buffer& out_im) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + oneapi::mkl::dft::BACKEND::compute_forward(desc, in_re, in_im, out_re, out_im); +} +sycl::event forward_op_rr(descriptor_type& desc, scalar_type* in_re, scalar_type* in_im, + scalar_type* out_re, scalar_type* out_im, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + return oneapi::mkl::dft::BACKEND::compute_forward(desc, in_re, in_im, out_re, out_im, + dependencies); +} -#define ONEAPI_MKL_DFT_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ - /*Buffer version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - oneapi::mkl::dft::BACKEND::compute_forward, \ - T_REAL>(desc, inout); \ - } \ - \ - /*In-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, T_BACKWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - oneapi::mkl::dft::BACKEND::compute_forward, \ - T_BACKWARD>(desc, inout); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout_re, \ - sycl::buffer & inout_im) { \ - oneapi::mkl::dft::BACKEND::compute_forward, \ - T_REAL>(desc, inout_re, inout_im); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_FORWARD, T_BACKWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - oneapi::mkl::dft::BACKEND::compute_forward, \ - T_FORWARD, T_BACKWARD>(desc, in, out); \ - } \ - \ - /*Out-of-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - oneapi::mkl::dft::BACKEND::compute_forward, \ - T_REAL, T_REAL>(desc, in, out); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im) { \ - oneapi::mkl::dft::BACKEND::compute_forward, \ - T_REAL, T_REAL>(desc, in_re, in_im, out_re, \ - out_im); \ - } \ - \ - /*USM version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT sycl::event compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_REAL>(desc, inout, dependencies); \ - } \ - \ - /*In-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_BACKWARD>( \ - dft::detail::descriptor & desc, T_BACKWARD * inout, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_BACKWARD>(desc, inout, dependencies); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_REAL>(desc, inout_re, inout_im, \ - dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_FORWARD, T_BACKWARD>( \ - dft::detail::descriptor & desc, T_FORWARD * in, T_BACKWARD * out, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_FORWARD, T_BACKWARD>(desc, in, out, \ - dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in, T_REAL * out, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_REAL, T_REAL>(desc, in, out, \ - dependencies); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_REAL, T_REAL>( \ - desc, in_re, in_im, out_re, out_im, dependencies); \ - } \ - \ - /*Buffer version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - oneapi::mkl::dft::BACKEND::compute_backward, \ - T_REAL>(desc, inout); \ - } \ - \ - /*In-place transform - complex */ \ - template <> \ - ONEMKL_EXPORT void compute_backward, T_BACKWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - oneapi::mkl::dft::BACKEND::compute_backward, \ - T_BACKWARD>(desc, inout); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout_re, \ - sycl::buffer & inout_im) { \ - oneapi::mkl::dft::BACKEND::compute_backward, \ - T_REAL>(desc, inout_re, inout_im); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_BACKWARD, T_FORWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - oneapi::mkl::dft::BACKEND::compute_backward, \ - T_BACKWARD, T_FORWARD>(desc, in, out); \ - } \ - \ - /*Out-of-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - oneapi::mkl::dft::BACKEND::compute_backward, \ - T_REAL, T_REAL>(desc, in, out); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im) { \ - oneapi::mkl::dft::BACKEND::compute_backward, \ - T_REAL, T_REAL>(desc, in_re, in_im, out_re, \ - out_im); \ - } \ - \ - /*USM version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_REAL>(desc, inout, dependencies); \ - } \ - \ - /*In-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_BACKWARD>( \ - dft::detail::descriptor & desc, T_BACKWARD * inout, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_BACKWARD>(desc, inout, dependencies); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_REAL>(desc, inout_re, inout_im, \ - dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_BACKWARD, T_FORWARD>( \ - dft::detail::descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_BACKWARD, T_FORWARD>(desc, in, out, \ - dependencies); \ - } \ - \ - /*Out-of-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in, T_REAL * out, \ - const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_REAL, T_REAL>(desc, in, out, \ - dependencies); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_REAL, T_REAL>( \ - desc, in_re, in_im, out_re, out_im, dependencies); \ - } +// backward inplace COMPLEX_COMPLEX +void backward_ip_cc(descriptor_type& desc, sycl::buffer& inout) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + oneapi::mkl::dft::BACKEND::compute_backward(desc, inout); +} +sycl::event backward_ip_cc(descriptor_type& desc, fwd_type* inout, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_backward"); + return oneapi::mkl::dft::BACKEND::compute_backward(desc, inout, dependencies); +} -// Signatures with forward_t=complex, backwards_t=complex are already instantiated for complex domain -// but not real domain. -#define ONEAPI_MKL_DFT_REAL_ONLY_SIGNATURES(EXT, PRECISION, T_COMPLEX) \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, \ - T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out) { \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_COMPLEX, T_COMPLEX>( \ - desc, in, out); \ - } \ - \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event compute_forward< \ - dft::detail::descriptor, T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, T_COMPLEX * in, \ - T_COMPLEX * out, const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_forward< \ - dft::detail::descriptor, T_COMPLEX, T_COMPLEX>( \ - desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT void compute_backward, \ - T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out) { \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_COMPLEX, T_COMPLEX>( \ - desc, in, out); \ - } \ - \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event compute_backward< \ - dft::detail::descriptor, T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, T_COMPLEX * in, \ - T_COMPLEX * out, const std::vector& dependencies) { \ - return oneapi::mkl::dft::BACKEND::compute_backward< \ - dft::detail::descriptor, T_COMPLEX, T_COMPLEX>( \ - desc, in, out, dependencies); \ - } +// backward inplace REAL_REAL +void backward_ip_rr(descriptor_type& desc, sycl::buffer& inout_re, + sycl::buffer& inout_im) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + oneapi::mkl::dft::BACKEND::compute_backward(desc, inout_re, inout_im); +} +sycl::event backward_ip_rr(descriptor_type& desc, scalar_type* inout_re, scalar_type* inout_im, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_backward"); + return oneapi::mkl::dft::BACKEND::compute_backward(desc, inout_re, inout_im, dependencies); +} -ONEAPI_MKL_DFT_SIGNATURES(f, dft::detail::precision::SINGLE, dft::detail::domain::REAL, float, - float, std::complex) -ONEAPI_MKL_DFT_REAL_ONLY_SIGNATURES(f, dft::detail::precision::SINGLE, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(c, dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX, float, - std::complex, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(d, dft::detail::precision::DOUBLE, dft::detail::domain::REAL, double, - double, std::complex) -ONEAPI_MKL_DFT_REAL_ONLY_SIGNATURES(d, dft::detail::precision::DOUBLE, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(z, dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX, double, - std::complex, std::complex) -#undef ONEAPI_MKL_DFT_SIGNATURES +// backward out-of-place COMPLEX_COMPLEX +void backward_op_cc(descriptor_type& desc, sycl::buffer& in, + sycl::buffer& out) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + oneapi::mkl::dft::BACKEND::compute_backward(desc, in, out); +} +sycl::event backward_op_cc(descriptor_type& desc, bwd_type* in, fwd_type* out, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_backward"); + return oneapi::mkl::dft::BACKEND::compute_backward(desc, in, out, dependencies); +} -} // namespace dft -} // namespace mkl -} // namespace oneapi +// backward out-of-place REAL_REAL +void backward_op_rr(descriptor_type& desc, sycl::buffer& in_re, + sycl::buffer& in_im, sycl::buffer& out_re, + sycl::buffer& out_im) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + oneapi::mkl::dft::BACKEND::compute_backward(desc, in_re, in_im, out_re, out_im); +} +sycl::event backward_op_rr(descriptor_type& desc, scalar_type* in_re, scalar_type* in_im, + scalar_type* out_re, scalar_type* out_im, + const std::vector& dependencies) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_backward"); + return oneapi::mkl::dft::BACKEND::compute_backward(desc, in_re, in_im, out_re, out_im, + dependencies); +} diff --git a/src/dft/backends/backend_forward_instantiations.cxx b/src/dft/backends/backend_forward_instantiations.cxx index 2c4a881d7..a6ed371d5 100644 --- a/src/dft/backends/backend_forward_instantiations.cxx +++ b/src/dft/backends/backend_forward_instantiations.cxx @@ -27,49 +27,31 @@ using desc_cd_t = dft::detail::descriptor; using depends_vec_t = const std::vector &; -#define ONEMKL_DFT_FORWARD_INSTANTIATIONS(DESCRIPTOR_T, REAL_T, FORWARD_T, BACKWARD_T) \ - /* Buffer API */ \ - template ONEMKL_EXPORT void compute_forward(DESCRIPTOR_T & desc, \ - sycl::buffer &); \ - template ONEMKL_EXPORT void compute_forward( \ - DESCRIPTOR_T & desc, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_forward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_forward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_forward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - template ONEMKL_EXPORT void compute_forward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &, \ - sycl::buffer &, sycl::buffer &); \ - \ - /* USM API */ \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, REAL_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, BACKWARD_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, FORWARD_T *, BACKWARD_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, REAL_T *, REAL_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, REAL_T *, REAL_T *, depends_vec_t); \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, REAL_T *, REAL_T *, REAL_T *, REAL_T *, depends_vec_t); - -#define ONEMKL_DFT_FORWARD_INSTANTIATIONS_REAL_ONLY(DESCRIPTOR_T, COMPLEX_T) \ - /* Buffer API */ \ - template ONEMKL_EXPORT void compute_forward( \ - DESCRIPTOR_T & desc, sycl::buffer &, sycl::buffer &); \ - /* USM API */ \ - template ONEMKL_EXPORT sycl::event compute_forward( \ - DESCRIPTOR_T & desc, COMPLEX_T *, COMPLEX_T *, depends_vec_t); +#define ONEMKL_DFT_FORWARD_INSTANTIATIONS(DESCRIPTOR_T, SCALAR_T, FORWARD_T, BACKWARD_T) \ + /* Buffer API */ \ + template ONEMKL_EXPORT void compute_forward(DESCRIPTOR_T &, \ + sycl::buffer &); \ + template ONEMKL_EXPORT void compute_forward( \ + DESCRIPTOR_T &, sycl::buffer &, sycl::buffer &); \ + template ONEMKL_EXPORT void compute_forward( \ + DESCRIPTOR_T &, sycl::buffer &, sycl::buffer &); \ + template ONEMKL_EXPORT void compute_forward( \ + DESCRIPTOR_T &, sycl::buffer &, sycl::buffer &, \ + sycl::buffer &, sycl::buffer &); \ + \ + /* USM API */ \ + template ONEMKL_EXPORT sycl::event compute_forward(DESCRIPTOR_T &, FORWARD_T *, \ + depends_vec_t); \ + template ONEMKL_EXPORT sycl::event compute_forward(DESCRIPTOR_T &, SCALAR_T *, \ + SCALAR_T *, depends_vec_t); \ + template ONEMKL_EXPORT sycl::event compute_forward(DESCRIPTOR_T &, FORWARD_T *, \ + BACKWARD_T *, depends_vec_t); \ + template ONEMKL_EXPORT sycl::event compute_forward( \ + DESCRIPTOR_T &, SCALAR_T *, SCALAR_T *, SCALAR_T *, SCALAR_T *, depends_vec_t); ONEMKL_DFT_FORWARD_INSTANTIATIONS(desc_rf_t, float, float, std::complex) -ONEMKL_DFT_FORWARD_INSTANTIATIONS_REAL_ONLY(desc_rf_t, std::complex) ONEMKL_DFT_FORWARD_INSTANTIATIONS(desc_cf_t, float, std::complex, std::complex) ONEMKL_DFT_FORWARD_INSTANTIATIONS(desc_rd_t, double, double, std::complex) -ONEMKL_DFT_FORWARD_INSTANTIATIONS_REAL_ONLY(desc_rd_t, std::complex) ONEMKL_DFT_FORWARD_INSTANTIATIONS(desc_cd_t, double, std::complex, std::complex) #undef ONEMKL_DFT_FORWARD_INSTANTIATIONS diff --git a/src/dft/backends/backend_wrappers.cxx b/src/dft/backends/backend_wrappers.cxx index 5635eda52..5d0d2bddc 100644 --- a/src/dft/backends/backend_wrappers.cxx +++ b/src/dft/backends/backend_wrappers.cxx @@ -41,80 +41,6 @@ oneapi::mkl::dft::BACKEND::create_commit, oneapi::mkl::dft::BACKEND::create_commit, oneapi::mkl::dft::BACKEND::create_commit, oneapi::mkl::dft::BACKEND::create_commit, -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - /* Buffer API */ \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_FORWARD, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - /* USM API */ \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_FORWARD, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_forward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - /* Buffer API */ \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD, T_FORWARD>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - /* USM API */ \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD, T_FORWARD>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_BACKWARD, T_BACKWARD>, \ - oneapi::mkl::dft::BACKEND::compute_backward< \ - oneapi::mkl::dft::detail::descriptor, T_REAL, T_REAL>, - -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(oneapi::mkl::dft::detail::precision::SINGLE, - oneapi::mkl::dft::detail::domain::REAL, float, float, - std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(oneapi::mkl::dft::detail::precision::SINGLE, - oneapi::mkl::dft::detail::domain::COMPLEX, float, - std::complex, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(oneapi::mkl::dft::detail::precision::DOUBLE, - oneapi::mkl::dft::detail::domain::REAL, double, - double, std::complex) -ONEAPI_MKL_DFT_BACKEND_SIGNATURES(oneapi::mkl::dft::detail::precision::DOUBLE, - oneapi::mkl::dft::detail::domain::COMPLEX, double, - std::complex, std::complex) // clang-format on #undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES diff --git a/src/dft/backends/cufft/CMakeLists.txt b/src/dft/backends/cufft/CMakeLists.txt new file mode 100644 index 000000000..010905546 --- /dev/null +++ b/src/dft/backends/cufft/CMakeLists.txt @@ -0,0 +1,85 @@ +#=============================================================================== +# Copyright Codeplay Software Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_dft_cufft) +set(LIB_OBJ ${LIB_NAME}_obj) + + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + descriptor.cpp + commit.cpp + forward.cpp + backward.cpp + $<$: mkl_dft_cufft_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_dft ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_NAME} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +if (${CMAKE_VERSION} VERSION_LESS "3.17.0") + find_package(CUDA REQUIRED) + target_include_directories(${LIB_OBJ} PRIVATE ${CUDA_INCLUDE_DIRS}) + target_link_libraries(${LIB_OBJ} PRIVATE cuda ${CUDA_CUFFT_LIBRARIES}) +else() + find_package(CUDAToolkit REQUIRED) + target_link_libraries(${LIB_OBJ} PRIVATE CUDA::cufft CUDA::cuda_driver) +endif() + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/cufft/backward.cpp b/src/dft/backends/cufft/backward.cpp new file mode 100644 index 000000000..aea9f232f --- /dev/null +++ b/src/dft/backends/cufft/backward.cpp @@ -0,0 +1,245 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "execute_helper.hpp" + +#include + +namespace oneapi::mkl::dft::cufft { +namespace detail { +//forward declaration +template +std::array get_offsets_bwd(dft::detail::commit_impl *commit); + +template +cufftHandle get_bwd_plan(dft::detail::commit_impl *commit) { + return static_cast *>(commit->get_handle())[1].value(); +} +} // namespace detail +// BUFFER version + +//In-place transform +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout) { + const std::string func_name = "compute_backward(desc, inout)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if constexpr (std::is_floating_point_v>) { + offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer + if (offsets[1] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + } + + queue.submit([&](sycl::handler &cgh) { + auto inout_acc = inout.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + auto inout_native = reinterpret_cast *>( + ih.get_native_mem(inout_acc)); + detail::cufft_execute>( + func_name, stream, plan, reinterpret_cast(inout_native + offsets[0]), + reinterpret_cast(inout_native + offsets[1])); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_backward(descriptor_type &, sycl::buffer, 1> &, + sycl::buffer, 1> &) { + throw oneapi::mkl::unimplemented("DFT", "compute_backward(desc, inout_re, inout_im)", + "cuFFT does not support real-real complex storage."); +} + +//Out-of-place transform +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + const std::string func_name = "compute_backward(desc, in, out)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if constexpr (std::is_floating_point_v>) { + if (offsets[1] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + } + + queue.submit([&](sycl::handler &cgh) { + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + auto in_native = reinterpret_cast( + reinterpret_cast *>( + ih.get_native_mem(in_acc)) + + offsets[0]); + auto out_native = reinterpret_cast( + reinterpret_cast *>( + ih.get_native_mem(out_acc)) + + offsets[1]); + detail::cufft_execute>( + func_name, stream, plan, in_native, out_native); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_backward(descriptor_type &, sycl::buffer, 1> &, + sycl::buffer, 1> &, + sycl::buffer, 1> &, + sycl::buffer, 1> &) { + throw oneapi::mkl::unimplemented("DFT", "compute_backward(desc, in_re, in_im, out_re, out_im)", + "cuFFT does not support real-real complex storage."); +} + +//USM version + +//In-place transform +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd *inout, + const std::vector &dependencies) { + const std::string func_name = "compute_backward(desc, inout, dependencies)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if constexpr (std::is_floating_point_v>) { + offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer + if (offsets[1] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + } + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + detail::cufft_execute>( + func_name, stream, plan, inout + offsets[0], inout + offsets[1]); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &, scalar *, + scalar *, + const std::vector &) { + throw oneapi::mkl::unimplemented("DFT", + "compute_backward(desc, inout_re, inout_im, dependencies)", + "cuFFT does not support real-real complex storage."); +} + +//Out-of-place transform +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd *in, + fwd *out, + const std::vector &dependencies) { + const std::string func_name = "compute_backward(desc, in, out, dependencies)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if constexpr (std::is_floating_point_v>) { + if (offsets[1] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + } + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + detail::cufft_execute>( + func_name, stream, plan, in + offsets[0], out + offsets[1]); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &, scalar *, + scalar *, scalar *, + scalar *, + const std::vector &) { + throw oneapi::mkl::unimplemented("DFT", + "compute_backward(desc, in_re, in_im, out_re, out_im, deps)", + "cuFFT does not support real-real complex storage."); +} + +// Template function instantiations +#include "dft/backends/backend_backward_instantiations.cxx" + +} // namespace oneapi::mkl::dft::cufft diff --git a/src/dft/backends/cufft/commit.cpp b/src/dft/backends/cufft/commit.cpp new file mode 100644 index 000000000..faf4332c0 --- /dev/null +++ b/src/dft/backends/cufft/commit.cpp @@ -0,0 +1,462 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include +#include +#include + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "../stride_helper.hpp" + +#include +#include + +namespace oneapi::mkl::dft::cufft { +namespace detail { + +/// Commit impl class specialization for cuFFT. +template +class cufft_commit final : public dft::detail::commit_impl { +private: + using scalar_type = typename dft::detail::commit_impl::scalar_type; + + // For real to complex transforms, the "type" arg also encodes the direction (e.g. CUFFT_R2C vs CUFFT_C2R) in the plan so we must have one for each direction. + // We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while cuFFT uses a directional "idist" and "odist". + // plans[0] is forward, plans[1] is backward + std::array, 2> plans = { std::nullopt, std::nullopt }; + std::int64_t offset_fwd_in, offset_fwd_out, offset_bwd_in, offset_bwd_out; + +public: + cufft_commit(sycl::queue& queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::cufft, + config_values) { + if constexpr (prec == dft::detail::precision::DOUBLE) { + if (!queue.get_device().has(sycl::aspect::fp64)) { + throw mkl::exception("DFT", "commit", "Device does not support double precision."); + } + } + } + + void clean_plans() { + auto fix_context = plans[0].has_value() || plans[1].has_value(); + if (plans[0]) { + if (cufftDestroy(plans[0].value()) != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, + "Failed to destroy forward cuFFT plan."); + } + plans[0] = std::nullopt; + } + if (plans[1]) { + if (cufftDestroy(plans[1].value()) != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, + "Failed to destroy backward cuFFT plan."); + } + plans[1] = std::nullopt; + } + if (fix_context) { + // cufftDestroy changes the context so change it back. + CUdevice interopDevice = + sycl::get_native(this->get_queue().get_device()); + CUcontext interopContext; + if (cuDevicePrimaryCtxRetain(&interopContext, interopDevice) != CUDA_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, + "Failed to change cuda context."); + } + } + } + + void commit(const dft::detail::dft_values& config_values) override { + // this could be a recommit + this->external_workspace_helper_ = + oneapi::mkl::dft::detail::external_workspace_helper( + config_values.workspace_placement == + oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); + clean_plans(); + + if (config_values.fwd_scale != 1.0 || config_values.bwd_scale != 1.0) { + throw mkl::unimplemented( + "dft/backends/cufft", __FUNCTION__, + "cuFFT does not support values other than 1 for FORWARD/BACKWARD_SCALE"); + } + + // The cudaStream for the plan is set at execution time so the interop handler can pick the stream. + constexpr cufftType fwd_type = [] { + if constexpr (dom == dft::domain::COMPLEX) { + if constexpr (prec == dft::precision::SINGLE) { + return CUFFT_C2C; + } + else { + return CUFFT_Z2Z; + } + } + else { + if constexpr (prec == dft::precision::SINGLE) { + return CUFFT_R2C; + } + else { + return CUFFT_D2Z; + } + } + }(); + constexpr cufftType bwd_type = [] { + if constexpr (dom == dft::domain::COMPLEX) { + if constexpr (prec == dft::precision::SINGLE) { + return CUFFT_C2C; + } + else { + return CUFFT_Z2Z; + } + } + else { + if constexpr (prec == dft::precision::SINGLE) { + return CUFFT_C2R; + } + else { + return CUFFT_Z2D; + } + } + }(); + + constexpr std::size_t max_supported_dims = 3; + std::array n_copy; + std::copy(config_values.dimensions.begin(), config_values.dimensions.end(), n_copy.data()); + const int rank = static_cast(config_values.dimensions.size()); + + auto stride_api_choice = dft::detail::get_stride_api(config_values); + dft::detail::throw_on_invalid_stride_api("CUFFT commit", stride_api_choice); + dft::detail::stride_vectors stride_vecs(config_values, stride_api_choice); + offset_fwd_in = stride_vecs.offset_fwd_in; + offset_fwd_out = stride_vecs.offset_fwd_out; + offset_bwd_in = stride_vecs.offset_bwd_in; + offset_bwd_out = stride_vecs.offset_bwd_out; + + // cufft ignores the first value in inembed and onembed, so there is no harm in putting offset there + auto a_min = std::min_element(stride_vecs.vec_a.begin() + 1, stride_vecs.vec_a.end()); + auto b_min = std::min_element(stride_vecs.vec_b.begin() + 1, stride_vecs.vec_b.end()); + if constexpr (dom == dft::domain::REAL) { + if ((a_min != stride_vecs.vec_a.begin() + rank) || + (b_min != stride_vecs.vec_b.begin() + rank)) { + throw mkl::unimplemented( + "dft/backends/cufft", __FUNCTION__, + "cufft requires the last stride to be the the smallest one for real transforms!"); + } + } + else { + if (a_min - stride_vecs.vec_a.begin() != b_min - stride_vecs.vec_b.begin()) { + throw mkl::unimplemented( + "dft/backends/cufft", __FUNCTION__, + "cufft requires that if ordered by stride length, the order of strides is the same for input and output strides!"); + } + } + const int a_stride = static_cast(*a_min); + const int b_stride = static_cast(*b_min); + stride_vecs.vec_a.erase(a_min); + stride_vecs.vec_b.erase(b_min); + int fwd_istride = a_stride; + int fwd_ostride = b_stride; + int bwd_istride = + stride_api_choice == dft::detail::stride_api::FB_STRIDES ? b_stride : a_stride; + int bwd_ostride = + stride_api_choice == dft::detail::stride_api::FB_STRIDES ? a_stride : b_stride; + if (a_min - stride_vecs.vec_a.begin() != rank) { + // swap dimensions to have the last one have the smallest stride + std::swap(n_copy[a_min - stride_vecs.vec_a.begin() - 1], n_copy[rank - 1]); + } + for (int i = 1; i < rank; i++) { + if ((stride_vecs.vec_a[i] % a_stride != 0) || (stride_vecs.vec_b[i] % b_stride != 0)) { + throw mkl::unimplemented( + "dft/backends/cufft", __FUNCTION__, + "cufft requires a stride to be divisible by all smaller strides!"); + } + stride_vecs.vec_a[i] /= a_stride; + stride_vecs.vec_b[i] /= b_stride; + } + if (rank > 2) { + if (stride_vecs.vec_a[1] > stride_vecs.vec_a[2] && + stride_vecs.vec_b[1] < stride_vecs.vec_b[2]) { + throw mkl::unimplemented( + "dft/backends/cufft", __FUNCTION__, + "cufft requires that if ordered by stride length, the order of strides is the same for input and output strides!"); + } + else if (stride_vecs.vec_a[1] < stride_vecs.vec_a[2] && + stride_vecs.vec_b[1] < stride_vecs.vec_b[2]) { + // swap dimensions to have the first one have the biggest stride + std::swap(stride_vecs.vec_a[1], stride_vecs.vec_a[2]); + std::swap(stride_vecs.vec_b[1], stride_vecs.vec_b[2]); + std::swap(n_copy[0], n_copy[1]); + } + if ((stride_vecs.vec_a[1] % stride_vecs.vec_a[2] != 0) || + (stride_vecs.vec_b[1] % stride_vecs.vec_b[2] != 0)) { + throw mkl::unimplemented( + "dft/backends/cufft", __FUNCTION__, + "cufft requires a stride to be divisible by all smaller strides!"); + } + stride_vecs.vec_a[1] /= stride_vecs.vec_a[2]; + stride_vecs.vec_b[1] /= stride_vecs.vec_b[2]; + } + const int batch = static_cast(config_values.number_of_transforms); + const int fwd_dist = static_cast(config_values.fwd_dist); + const int bwd_dist = static_cast(config_values.bwd_dist); + + // When creating real-complex descriptions, the strides will always be wrong for one of the directions. + // This is because the least significant dimension is symmetric. + // If the strides are invalid (too small to fit) then just don't bother creating the plan + auto check_stride_validity = [&](auto strides_fwd, auto strides_bwd) { + int inner_nfwd = n_copy[rank - 1]; // inner dimensions of DFT + // Complex data is stored conjugate even for real domains + int inner_nbwd = dom == dft::domain::REAL ? inner_nfwd / 2 + 1 : inner_nfwd; + int inner_sfwd = strides_fwd.back(); // inner strides of DFT + int inner_sbwd = strides_bwd.back(); + bool valid = true; + for (int r = 1; r < rank; ++r) { + valid = valid && (inner_nfwd <= inner_sfwd) && (inner_nbwd <= inner_sbwd); + inner_nfwd *= n_copy[rank - r - 1]; + inner_nbwd *= n_copy[rank - r - 1]; + inner_sfwd *= strides_fwd[rank - r - 1]; + inner_sbwd *= strides_bwd[rank - r - 1]; + } + return valid; + }; + + bool valid_forward = check_stride_validity(stride_vecs.fwd_in, stride_vecs.fwd_out); + bool valid_backward = stride_api_choice == dft::detail::stride_api::FB_STRIDES + ? valid_forward + : check_stride_validity(stride_vecs.bwd_out, stride_vecs.bwd_in); + + if (!valid_forward && !valid_backward) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Invalid strides."); + } + + if (valid_forward) { + cufftHandle fwd_plan; + auto res = cufftCreate(&fwd_plan); + if (res != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, "cufftCreate failed."); + } + apply_external_workspace_setting(fwd_plan, config_values.workspace_placement); + res = cufftPlanMany(&fwd_plan, // plan + rank, // rank + n_copy.data(), // n + stride_vecs.fwd_in.data(), // inembed + fwd_istride, // istride + fwd_dist, // idist + stride_vecs.fwd_out.data(), // onembed + fwd_ostride, // ostride + bwd_dist, // odist + fwd_type, // type + batch // batch + ); + + if (res != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, + "Failed to create forward cuFFT plan."); + } + + plans[0] = fwd_plan; + } + + if (valid_backward) { + cufftHandle bwd_plan; + auto res = cufftCreate(&bwd_plan); + if (res != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, "cufftCreate failed."); + } + apply_external_workspace_setting(bwd_plan, config_values.workspace_placement); + // flip fwd_distance and bwd_distance because cuFFt uses input distance and output distance. + res = cufftPlanMany(&bwd_plan, // plan + rank, // rank + n_copy.data(), // n + stride_vecs.bwd_in.data(), // inembed + bwd_istride, // istride + bwd_dist, // idist + stride_vecs.bwd_out.data(), // onembed + bwd_ostride, // ostride + fwd_dist, // odist + bwd_type, // type + batch // batch + ); + if (res != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, + "Failed to create backward cuFFT plan."); + } + plans[1] = bwd_plan; + } + } + + ~cufft_commit() override { + clean_plans(); + } + + static void apply_external_workspace_setting(cufftHandle handle, + config_value workspace_setting) { + if (workspace_setting == config_value::WORKSPACE_EXTERNAL) { + auto res = cufftSetAutoAllocation(handle, 0); + if (res != CUFFT_SUCCESS) { + throw mkl::exception("dft/backends/cufft", "commit", + "cufftSetAutoAllocation(plan, 0) failed."); + } + } + } + + void* get_handle() noexcept override { + return plans.data(); + } + + std::array get_offsets_fwd() noexcept { + return { offset_fwd_in, offset_fwd_out }; + } + + std::array get_offsets_bwd() noexcept { + return { offset_bwd_in, offset_bwd_out }; + } + + virtual void set_workspace(scalar_type* usm_workspace) override { + this->external_workspace_helper_.set_workspace_throw(*this, usm_workspace); + if (plans[0]) { + cufftSetWorkArea(*plans[0], usm_workspace); + } + if (plans[1]) { + cufftSetWorkArea(*plans[1], usm_workspace); + } + } + + void set_buffer_workspace(cufftHandle plan, sycl::buffer& buffer_workspace) { + this->get_queue() + .submit([&](sycl::handler& cgh) { + auto workspace_acc = + buffer_workspace.template get_access(cgh); + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = ih.get_native_queue(); + auto result = cufftSetStream(plan, stream); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception( + "dft/backends/cufft", "set_workspace", + "cufftSetStream returned " + std::to_string(result)); + } + auto workspace_native = reinterpret_cast( + ih.get_native_mem(workspace_acc)); + cufftSetWorkArea(plan, workspace_native); + }); + }) + .wait_and_throw(); + } + + virtual void set_workspace(sycl::buffer& buffer_workspace) override { + this->external_workspace_helper_.set_workspace_throw(*this, buffer_workspace); + if (plans[0]) { + set_buffer_workspace(*plans[0], buffer_workspace); + } + if (plans[1]) { + set_buffer_workspace(*plans[1], buffer_workspace); + } + } + + std::int64_t get_plan_workspace_size_bytes(cufftHandle handle) { + std::size_t size = 0; + cufftGetSize(handle, &size); + std::int64_t padded_size = static_cast(size); + return padded_size; + } + + virtual std::int64_t get_workspace_external_bytes_impl() override { + std::int64_t size0 = plans[0] ? get_plan_workspace_size_bytes(*plans[0]) : 0; + std::int64_t size1 = plans[1] ? get_plan_workspace_size_bytes(*plans[1]) : 0; + return std::max(size0, size1); + }; + +#define BACKEND cufft +#include "../backend_compute_signature.cxx" +#undef BACKEND +}; +} // namespace detail + +template +dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { + return new detail::cufft_commit(sycl_queue, desc.get_values()); +} + +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); + +namespace detail { +template +std::array get_offsets_fwd(dft::detail::commit_impl* commit) { + return static_cast*>(commit)->get_offsets_fwd(); +} + +template +std::array get_offsets_bwd(dft::detail::commit_impl* commit) { + return static_cast*>(commit)->get_offsets_bwd(); +} + +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); + +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +} //namespace detail + +} // namespace oneapi::mkl::dft::cufft diff --git a/src/dft/backends/cufft/descriptor.cpp b/src/dft/backends/cufft/descriptor.cpp new file mode 100644 index 000000000..d102164c2 --- /dev/null +++ b/src/dft/backends/cufft/descriptor.cpp @@ -0,0 +1,49 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "../../descriptor.cxx" + +#include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(backend_selector selector) { + if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) { + if (pimpl_) { + pimpl_->get_queue().wait(); + } + pimpl_.reset(cufft::create_commit(*this, selector.get_queue())); + } + pimpl_->commit(values_); +} + +template void descriptor::commit( + backend_selector); +template void descriptor::commit(backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit(backend_selector); + +} //namespace dft +} //namespace mkl +} //namespace oneapi diff --git a/src/dft/backends/cufft/execute_helper.hpp b/src/dft/backends/cufft/execute_helper.hpp new file mode 100644 index 000000000..776f0f254 --- /dev/null +++ b/src/dft/backends/cufft/execute_helper.hpp @@ -0,0 +1,148 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_CUFFT_EXECUTE_HPP_ +#define _ONEMKL_DFT_SRC_CUFFT_EXECUTE_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/exceptions.hpp" + +#include +#include + +namespace oneapi::mkl::dft::cufft::detail { + +template +inline dft::detail::commit_impl *checked_get_commit( + dft::detail::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::cufft) { + throw mkl::invalid_argument("dft/backends/cufft", "get_commit", + "DFT descriptor has not been commited for cuFFT"); + } + return commit_handle; +} + +/// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +/// the expected value. +template +inline auto expect_config(DescT &desc, const char *message) { + dft::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("dft/backends/cufft", "expect_config", message); + } +} + +enum class Direction { Forward = CUFFT_FORWARD, Backward = CUFFT_INVERSE }; + +template +void cufft_execute(const std::string &func, CUstream stream, cufftHandle plan, void *input, + void *output) { + constexpr bool is_real = std::is_floating_point_v; + using single_type = std::conditional_t>; + constexpr bool is_single = std::is_same_v; + + if constexpr (is_real) { + if constexpr (dir == Direction::Forward) { + if constexpr (is_single) { + auto result = cufftExecR2C(plan, reinterpret_cast(input), + reinterpret_cast(output)); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftExecR2C returned " + std::to_string(result)); + } + } + else { + auto result = cufftExecD2Z(plan, reinterpret_cast(input), + reinterpret_cast(output)); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftExecD2Z returned " + std::to_string(result)); + } + } + } + else { + if constexpr (is_single) { + auto result = cufftExecC2R(plan, reinterpret_cast(input), + reinterpret_cast(output)); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftExecC2R returned " + std::to_string(result)); + } + } + else { + auto result = cufftExecZ2D(plan, reinterpret_cast(input), + reinterpret_cast(output)); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftExecZ2D returned " + std::to_string(result)); + } + } + } + } + else { + if constexpr (is_single) { + auto result = + cufftExecC2C(plan, reinterpret_cast(input), + reinterpret_cast(output), static_cast(dir)); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftExecC2C returned " + std::to_string(result)); + } + } + else { + auto result = + cufftExecZ2Z(plan, reinterpret_cast(input), + reinterpret_cast(output), static_cast(dir)); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftExecZ2Z returned " + std::to_string(result)); + } + } + } + + auto result = cuStreamSynchronize(stream); + if (result != CUDA_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cuStreamSynchronize returned " + std::to_string(result)); + } +} + +inline CUstream setup_stream(const std::string &func, sycl::interop_handle ih, cufftHandle plan) { + auto stream = ih.get_native_queue(); + auto result = cufftSetStream(plan, stream); + if (result != CUFFT_SUCCESS) { + throw oneapi::mkl::exception("dft/backends/cufft", func, + "cufftSetStream returned " + std::to_string(result)); + } + return stream; +} + +} // namespace oneapi::mkl::dft::cufft::detail + +#endif diff --git a/src/dft/backends/cufft/forward.cpp b/src/dft/backends/cufft/forward.cpp new file mode 100644 index 000000000..fb323c085 --- /dev/null +++ b/src/dft/backends/cufft/forward.cpp @@ -0,0 +1,247 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "execute_helper.hpp" + +#include + +namespace oneapi::mkl::dft::cufft { + +namespace detail { +//forward declaration +template +std::array get_offsets_fwd(dft::detail::commit_impl *commit); + +template +cufftHandle get_fwd_plan(dft::detail::commit_impl *commit) { + return static_cast *>(commit->get_handle())[0].value(); +} +} // namespace detail + +// BUFFER version + +//In-place transform +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout) { + const std::string func_name = "compute_forward(desc, inout)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if constexpr (std::is_floating_point_v>) { + if (offsets[0] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + offsets[1] *= 2; // offset is supplied in complex but we offset scalar pointer + } + + queue.submit([&](sycl::handler &cgh) { + auto inout_acc = inout.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + auto inout_native = reinterpret_cast *>( + ih.get_native_mem(inout_acc)); + detail::cufft_execute>( + func_name, stream, plan, reinterpret_cast(inout_native + offsets[0]), + reinterpret_cast(inout_native + offsets[1])); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_forward(descriptor_type &, sycl::buffer, 1> &, + sycl::buffer, 1> &) { + throw oneapi::mkl::unimplemented("DFT", "compute_forward(desc, inout_re, inout_im)", + "cuFFT does not support real-real complex storage."); +} + +//Out-of-place transform +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + const std::string func_name = "compute_forward(desc, in, out)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if constexpr (std::is_floating_point_v>) { + if (offsets[0] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + } + + queue.submit([&](sycl::handler &cgh) { + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + auto in_native = reinterpret_cast( + reinterpret_cast *>( + ih.get_native_mem(in_acc)) + + offsets[0]); + auto out_native = reinterpret_cast( + reinterpret_cast *>( + ih.get_native_mem(out_acc)) + + offsets[1]); + detail::cufft_execute>( + func_name, stream, plan, in_native, out_native); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_forward(descriptor_type &, sycl::buffer, 1> &, + sycl::buffer, 1> &, + sycl::buffer, 1> &, + sycl::buffer, 1> &) { + throw oneapi::mkl::unimplemented("DFT", "compute_forward(desc, in_re, in_im, out_re, out_im)", + "cuFFT does not support real-real complex storage."); +} + +//USM version + +//In-place transform +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *inout, + const std::vector &dependencies) { + const std::string func_name = "compute_forward(desc, inout, dependencies)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if constexpr (std::is_floating_point_v>) { + if (offsets[0] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + offsets[1] *= 2; // offset is supplied in complex but we offset scalar pointer + } + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + detail::cufft_execute>( + func_name, stream, plan, inout + offsets[0], inout + offsets[1]); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &, scalar *, + scalar *, + const std::vector &) { + throw oneapi::mkl::unimplemented("DFT", + "compute_forward(desc, inout_re, inout_im, dependencies)", + "cuFFT does not support real-real complex storage."); +} + +//Out-of-place transform +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *in, + bwd *out, + const std::vector &dependencies) { + const std::string func_name = "compute_forward(desc, in, out, dependencies)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if constexpr (std::is_floating_point_v>) { + if (offsets[0] % 2 != 0) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "cuFFT requires offset (first value in strides) to be multiple of 2!"); + } + } + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(dependencies); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, plan); + + detail::cufft_execute>( + func_name, stream, plan, in + offsets[0], out + offsets[1]); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &, scalar *, + scalar *, scalar *, + scalar *, + const std::vector &) { + throw oneapi::mkl::unimplemented( + "DFT", "compute_forward(desc, in_re, in_im, out_re, out_im, dependencies)", + "cuFFT does not support real-real complex storage."); +} + +// Template function instantiations +#include "dft/backends/backend_forward_instantiations.cxx" + +} // namespace oneapi::mkl::dft::cufft diff --git a/src/dft/backends/cufft/mkl_dft_cufft_wrappers.cpp b/src/dft/backends/cufft/mkl_dft_cufft_wrappers.cpp new file mode 100644 index 000000000..93d3aae11 --- /dev/null +++ b/src/dft/backends/cufft/mkl_dft_cufft_wrappers.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp" +#include "dft/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND cufft + +extern "C" dft_function_table_t mkl_dft_table = { + WRAPPER_VERSION, +#include "dft/backends/backend_wrappers.cxx" +}; + +#undef WRAPPER_VERSION +#undef BACKEND diff --git a/src/dft/backends/mklcpu/CMakeLists.txt b/src/dft/backends/mklcpu/CMakeLists.txt index f4fc4c016..6d0f1276d 100644 --- a/src/dft/backends/mklcpu/CMakeLists.txt +++ b/src/dft/backends/mklcpu/CMakeLists.txt @@ -20,8 +20,7 @@ set(LIB_NAME onemkl_dft_mklcpu) set(LIB_OBJ ${LIB_NAME}_obj) -set(USE_DPCPP_API ON) -find_package(MKL REQUIRED) +include(WarningsUtils) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT @@ -29,22 +28,41 @@ add_library(${LIB_OBJ} OBJECT descriptor.cpp forward.cpp backward.cpp - compute_signature.cpp $<$: mkl_dft_cpu_wrappers.cpp> ) +add_dependencies(onemkl_backend_libs_dft ${LIB_NAME}) target_include_directories(${LIB_OBJ} - PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/src + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_NAME} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT} -DBUILD_COMP) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) endif() -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) + +if(TARGET MKL::MKL_SYCL::DFT) + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_SYCL::DFT + PRIVATE onemkl_warnings + ) +else() + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_DPCPP + PRIVATE onemkl_warnings + ) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/dft/backends/mklcpu/backward.cpp b/src/dft/backends/mklcpu/backward.cpp index 36bb41d9a..fe7186630 100644 --- a/src/dft/backends/mklcpu/backward.cpp +++ b/src/dft/backends/mklcpu/backward.cpp @@ -23,90 +23,308 @@ #include #endif -#include "oneapi/mkl/types.hpp" -#include "oneapi/mkl/dft/types.hpp" -#include "oneapi/mkl/detail/exceptions.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "dft/backends/mklcpu/commit_derived_impl.hpp" -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklcpu { +// MKLCPU header +#include "mkl_dfti.h" + +namespace oneapi::mkl::dft::mklcpu { +namespace detail { // BUFFER version +// backward a MKLCPU DFT call to the backend, checking that the commit impl is valid. +template +inline void check_bwd_commit(dft::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklcpu) { + throw mkl::invalid_argument("DFT", "computer_backward", + "DFT descriptor has not been commited for MKLCPU"); + } + + auto mklcpu_desc = reinterpret_cast(commit_handle->get_handle()); + MKL_LONG commit_status{ DFTI_UNCOMMITTED }; + DftiGetValue(mklcpu_desc[1], DFTI_COMMIT_STATUS, &commit_status); + if (commit_status != DFTI_COMMITTED) { + throw mkl::invalid_argument("DFT", "compute_backward", + "MKLCPU DFT descriptor was not successfully committed."); + } +} + +// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +// the expected value. +template +inline auto expect_config(DescT &desc, const char *message) { + dft::detail::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("DFT", "compute_backward", message); + } +} +// convert the base commit class to derived cpu commit class +template +auto get_buffer(commit_t *commit_handle) { + commit_derived_t *derived_commit = + static_cast *>(commit_handle); + return derived_commit->get_handle_buffer(); +} +} // namespace detail //In-place transform -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inout_acc = inout.template get_access(cgh); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeBackward(desc_acc[detail::DIR::bwd], detail::acc_to_ptr(inout_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout_re, + sycl::buffer, 1> &inout_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto re_acc = inout_re.template get_access(cgh); + auto im_acc = inout_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward( + desc_acc[detail::DIR::bwd], detail::acc_to_ptr(re_acc), detail::acc_to_ptr(im_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + + detail::host_task(cgh, [=]() { + auto in_ptr = const_cast *>(detail::acc_to_ptr(in_acc)); + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], in_ptr, + detail::acc_to_ptr(out_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in_re, + sycl::buffer, 1> &in_im, + sycl::buffer, 1> &out_re, + sycl::buffer, 1> &out_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inre_acc = in_re.template get_access(cgh); + auto inim_acc = in_im.template get_access(cgh); + auto outre_acc = out_re.template get_access(cgh); + auto outim_acc = out_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + auto inre_ptr = const_cast *>(detail::acc_to_ptr(inre_acc)); + auto inim_ptr = const_cast *>(detail::acc_to_ptr(inim_acc)); + DFT_ERROR status = + DftiComputeBackward(desc_acc[detail::DIR::bwd], inre_ptr, inim_ptr, + detail::acc_to_ptr(outre_acc), detail::acc_to_ptr(outim_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //USM version //In-place transform -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout, +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd *inout, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; + detail::expect_config( + desc, "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], inout); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, - data_type *inout_im, +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar *inout_re, + scalar *inout_im, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; + detail::expect_config( + desc, "Unexpected value for complex storage"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], inout_re, inout_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd *in, + fwd *out, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; + // Check: inplace, complex storage + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeBackward(desc_acc[detail::DIR::bwd], in, out); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re, - input_type *in_im, output_type *out_re, - output_type *out_im, +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar *in_re, + scalar *in_im, + scalar *out_re, + scalar *out_im, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_backward", "Not implemented for MKLCPU"); - return sycl::event{}; + detail::expect_config( + desc, "Unexpected value for complex storage"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_bwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeBackward(desc_acc[detail::DIR::bwd], in_re, in_im, out_re, out_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "compute_backward", + std::string("DftiComputeBackward failed : ") + DftiErrorMessage(status)); + } + }); + }); } // Template function instantiations #include "dft/backends/backend_backward_instantiations.cxx" -} // namespace mklcpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklcpu diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index b510212eb..1ec8aef9c 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -31,6 +31,9 @@ #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" #include "oneapi/mkl/dft/detail/commit_impl.hpp" + +#include "dft/backends/mklcpu/commit_derived_impl.hpp" +#include "../stride_helper.hpp" #include "mkl_service.h" #include "mkl_dfti.h" @@ -38,106 +41,170 @@ namespace oneapi { namespace mkl { namespace dft { namespace mklcpu { +namespace detail { + +template +commit_derived_impl::commit_derived_impl( + sycl::queue queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklcpu, config_values) { + // create the descriptor once for the lifetime of the descriptor class + DFT_ERROR status[2] = { DFTI_BAD_DESCRIPTOR, DFTI_BAD_DESCRIPTOR }; -template -class commit_derived_impl final : public detail::commit_impl { -public: - commit_derived_impl(sycl::queue queue, const detail::dft_values& config_values) - : detail::commit_impl(queue, backend::mklcpu) { - DFT_ERROR status = DFT_NOTSET; + for (auto dir : { DIR::fwd, DIR::bwd }) { + const auto rank = static_cast(config_values.dimensions.size()); if (config_values.dimensions.size() == 1) { - status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), 1, - config_values.dimensions[0]); + status[dir] = DftiCreateDescriptor(&bidirection_handle[dir], mklcpu_prec, mklcpu_dom, 1, + config_values.dimensions[0]); } else { - status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), - config_values.dimensions.size(), - config_values.dimensions.data()); - } - if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception( - "dft/backends/mklcpu", "commit", - "DftiCreateDescriptor failed with status: " + std::to_string(status)); + status[dir] = DftiCreateDescriptor(&bidirection_handle[dir], mklcpu_prec, mklcpu_dom, + rank, config_values.dimensions.data()); } } - void commit(const detail::dft_values& config_values) override { - set_value(handle, config_values); - auto status = DftiCommitDescriptor(handle); - if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception( - "dft/backends/mklcpu", "commit", - "DftiCommitDescriptor failed with status: " + std::to_string(status)); - } + if (status[0] != DFTI_NO_ERROR || status[1] != DFTI_NO_ERROR) { + std::string err = std::string("DftiCreateDescriptor failed with status : ") + + DftiErrorMessage(status[0]) + std::string(", ") + + DftiErrorMessage(status[1]); + throw oneapi::mkl::exception("dft/backends/mklcpu", "create_descriptor", err); } +} - virtual void* get_handle() noexcept override { - return handle; +template +commit_derived_impl::~commit_derived_impl() { + for (auto dir : { DIR::fwd, DIR::bwd }) { + DftiFreeDescriptor(&bidirection_handle[dir]); } +} - virtual ~commit_derived_impl() override { - DftiFreeDescriptor((DFTI_DESCRIPTOR_HANDLE*)&handle); - } +template +void commit_derived_impl::commit( + const dft::detail::dft_values& config_values) { + this->external_workspace_helper_ = + oneapi::mkl::dft::detail::external_workspace_helper( + config_values.workspace_placement == + oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); + set_value(bidirection_handle.data(), config_values); + + this->get_queue() + .submit([&](sycl::handler& cgh) { + auto bidir_handle_obj = + bidirection_buffer.get_access(cgh); + + host_task>(cgh, [=]() { + DFT_ERROR status[2] = { DFTI_BAD_DESCRIPTOR, DFTI_BAD_DESCRIPTOR }; + + for (auto dir : { DIR::fwd, DIR::bwd }) + status[dir] = DftiCommitDescriptor(bidir_handle_obj[dir]); + + // this is important for real-batched transforms, as the backward transform would + // be inconsistent based on the stride setup, but once recommited before backward + // it should work just fine. so we error out only if there is a issue with both. + if (status[0] != DFTI_NO_ERROR && status[1] != DFTI_NO_ERROR) { + std::string err = std::string("DftiCommitDescriptor failed with status : ") + + DftiErrorMessage(status[0]) + std::string(", ") + + DftiErrorMessage(status[1]); + throw oneapi::mkl::exception("dft/backends/mklcpu", "commit", err); + } + }); + }) + .wait(); +} -private: - DFTI_DESCRIPTOR_HANDLE handle = nullptr; +template +void* commit_derived_impl::get_handle() noexcept { + return reinterpret_cast(bidirection_handle.data()); +} - constexpr DFTI_CONFIG_VALUE get_domain(domain d) { - if (d == domain::COMPLEX) { - return DFTI_COMPLEX; - } - else { - return DFTI_REAL; - } +template +template +void commit_derived_impl::set_value_item(mklcpu_desc_t hand, enum DFTI_CONFIG_PARAM name, + Args... args) { + DFT_ERROR value_err = DftiSetValue(hand, name, args...); + if (value_err != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft/backends/mklcpu", "set_value_item", + DftiErrorMessage(value_err)); } +} - constexpr DFTI_CONFIG_VALUE get_precision(precision p) { - if (p == precision::SINGLE) { - return DFTI_SINGLE; +template +void commit_derived_impl::set_value(mklcpu_desc_t* descHandle, + const dft::detail::dft_values& config) { + auto stride_choice = dft::detail::get_stride_api(config); + dft::detail::throw_on_invalid_stride_api("MKLCPU commit", stride_choice); + for (auto dir : { DIR::fwd, DIR::bwd }) { + if (stride_choice == dft::detail::stride_api::IO_STRIDES) { + set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.input_strides.data()); + set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.output_strides.data()); } - else { - return DFTI_DOUBLE; + else { // Forward / backward strides + if (dir == DIR::fwd) { + set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.fwd_strides.data()); + set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.bwd_strides.data()); + } + else { + set_value_item(descHandle[dir], DFTI_INPUT_STRIDES, config.bwd_strides.data()); + set_value_item(descHandle[dir], DFTI_OUTPUT_STRIDES, config.fwd_strides.data()); + } } - } - - template - void set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) { - if (auto ret = DftiSetValue(hand, name, args...); ret != DFTI_NO_ERROR) { - throw oneapi::mkl::exception( - "dft/backends/mklcpu", "set_value_item", - "name: " + std::to_string(name) + " error: " + std::to_string(ret)); + set_value_item(descHandle[dir], DFTI_BACKWARD_SCALE, config.bwd_scale); + set_value_item(descHandle[dir], DFTI_FORWARD_SCALE, config.fwd_scale); + set_value_item(descHandle[dir], DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); + set_value_item(descHandle[dir], DFTI_INPUT_DISTANCE, + (dir == detail::DIR::fwd) ? config.fwd_dist : config.bwd_dist); + set_value_item(descHandle[dir], DFTI_OUTPUT_DISTANCE, + (dir == detail::DIR::fwd) ? config.bwd_dist : config.fwd_dist); + set_value_item(descHandle[dir], DFTI_COMPLEX_STORAGE, + to_mklcpu(config.complex_storage)); + set_value_item(descHandle[dir], DFTI_REAL_STORAGE, + to_mklcpu(config.real_storage)); + set_value_item(descHandle[dir], DFTI_CONJUGATE_EVEN_STORAGE, + to_mklcpu(config.conj_even_storage)); + set_value_item(descHandle[dir], DFTI_PLACEMENT, + to_mklcpu(config.placement)); + set_value_item(descHandle[dir], DFTI_PACKED_FORMAT, + to_mklcpu(config.packed_format)); + // Setting the workspace causes an FFT_INVALID_DESCRIPTOR. + if (config.workspace != config_value::ALLOW) { + throw mkl::invalid_argument("dft/backends/mklcpu", "commit", + "MKLCPU only supports workspace set to allow"); + } + // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.ordering != dft::detail::config_value::ORDERED) { + throw mkl::invalid_argument("dft/backends/mklcpu", "commit", + "MKLCPU only supports ordered ordering."); + } + // Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.transpose != false) { + throw mkl::invalid_argument("dft/backends/mklcpu", "commit", + "MKLCPU only supports non-transposed."); } } +} +} // namespace detail - void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, - const detail::dft_values& config) { - set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); - set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); - set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); - set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); - set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - set_value_item( - descHandle, DFTI_PLACEMENT, - (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); - } -}; - -template -detail::commit_impl* create_commit(const descriptor& desc, - sycl::queue& sycl_queue) { - return new commit_derived_impl(sycl_queue, desc.get_values()); +template +dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { + return new detail::commit_derived_impl(sycl_queue, desc.get_values()); } -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); -template detail::commit_impl* create_commit( - const descriptor&, sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); } // namespace mklcpu } // namespace dft diff --git a/src/dft/backends/mklcpu/commit_derived_impl.hpp b/src/dft/backends/mklcpu/commit_derived_impl.hpp new file mode 100644 index 000000000..3551758a0 --- /dev/null +++ b/src/dft/backends/mklcpu/commit_derived_impl.hpp @@ -0,0 +1,88 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_ +#define _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_ + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" +#include "dft/backends/mklcpu/mklcpu_helpers.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + +namespace oneapi { +namespace mkl { +namespace dft { +namespace mklcpu { +namespace detail { + +// this is used for indexing bidirectional_handle +enum DIR { fwd = 0, bwd = 1 }; + +template +class commit_derived_impl final : public dft::detail::commit_impl { +private: + using scalar_type = typename dft::detail::commit_impl::scalar_type; + static constexpr DFTI_CONFIG_VALUE mklcpu_prec = to_mklcpu(prec); + static constexpr DFTI_CONFIG_VALUE mklcpu_dom = to_mklcpu(dom); + using mklcpu_desc_t = DFTI_DESCRIPTOR_HANDLE; + +public: + commit_derived_impl(sycl::queue queue, const dft::detail::dft_values& config_values); + + virtual void commit(const dft::detail::dft_values& config_values) override; + + virtual void* get_handle() noexcept override; + + virtual ~commit_derived_impl() override; + + sycl::buffer get_handle_buffer() noexcept { + return bidirection_buffer; + }; + +#define BACKEND mklcpu +#include "../backend_compute_signature.cxx" +#undef BACKEND + +private: + // bidirectional_handle[0] is the forward handle, bidirectional_handle[1] is the backward handle + std::array bidirection_handle{ nullptr, nullptr }; + sycl::buffer bidirection_buffer{ bidirection_handle.data(), + sycl::range<1>{ 2 } }; + + template + void set_value_item(mklcpu_desc_t hand, enum DFTI_CONFIG_PARAM name, Args... args); + + void set_value(mklcpu_desc_t* descHandle, const dft::detail::dft_values& config); +}; + +template +using commit_t = dft::detail::commit_impl; + +template +using commit_derived_t = detail::commit_derived_impl; + +} // namespace detail +} // namespace mklcpu +} // namespace dft +} // namespace mkl +} // namespace oneapi + +#endif // _ONEMKL_DFT_COMMIT_DERIVED_IMPL_HPP_ diff --git a/src/dft/backends/mklcpu/descriptor.cpp b/src/dft/backends/mklcpu/descriptor.cpp index b981d530f..2bb0e2835 100644 --- a/src/dft/backends/mklcpu/descriptor.cpp +++ b/src/dft/backends/mklcpu/descriptor.cpp @@ -32,7 +32,7 @@ void descriptor::commit(backend_selector selector) { if (pimpl_) { pimpl_->get_queue().wait(); } - pimpl_.reset(mklgpu::create_commit(*this, selector.get_queue())); + pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue())); } pimpl_->commit(values_); } diff --git a/src/dft/backends/mklcpu/forward.cpp b/src/dft/backends/mklcpu/forward.cpp index a73ea1b62..2e5e2fa88 100644 --- a/src/dft/backends/mklcpu/forward.cpp +++ b/src/dft/backends/mklcpu/forward.cpp @@ -23,88 +23,314 @@ #include #endif -#include "oneapi/mkl/types.hpp" -#include "oneapi/mkl/dft/types.hpp" -#include "oneapi/mkl/detail/exceptions.hpp" +#include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/dft/descriptor.hpp" #include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "dft/backends/mklcpu/commit_derived_impl.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + +namespace oneapi::mkl::dft::mklcpu { +namespace detail { + +// BUFFER version +// Forward a MKLCPU DFT call to the backend, checking that the commit impl is valid. +template +inline void check_fwd_commit(dft::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklcpu) { + throw mkl::invalid_argument("DFT", "computer_forward", + "DFT descriptor has not been commited for MKLCPU"); + } -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklcpu { + auto mklcpu_desc = reinterpret_cast(commit_handle->get_handle()); + MKL_LONG commit_status{ DFTI_UNCOMMITTED }; + DftiGetValue(mklcpu_desc[0], DFTI_COMMIT_STATUS, &commit_status); + if (commit_status != DFTI_COMMITTED) { + throw mkl::invalid_argument("DFT", "compute_forward", + "MKLCPU DFT descriptor was not successfully committed."); + } +} + +// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +// the expected value. +template +inline auto expect_config(DescT &desc, const char *message) { + dft::detail::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("DFT", "compute_forward", message); + } +} + +// convert the base commit class to derived cpu commit class +template +auto get_buffer(commit_t *commit_handle) { + commit_derived_t *derived_commit = + static_cast *>(commit_handle); + return derived_commit->get_handle_buffer(); +} +} // namespace detail //In-place transform -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inout_acc = inout.template get_access(cgh); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeForward(desc_acc[detail::DIR::fwd], detail::acc_to_ptr(inout_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout_re, + sycl::buffer, 1> &inout_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto re_acc = inout_re.template get_access(cgh); + auto im_acc = inout_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward( + desc_acc[detail::DIR::fwd], detail::acc_to_ptr(re_acc), detail::acc_to_ptr(im_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + + detail::host_task(cgh, [=]() { + auto in_ptr = const_cast *>(detail::acc_to_ptr(in_acc)); + DFT_ERROR status = + DftiComputeForward(desc_acc[detail::DIR::fwd], in_ptr, detail::acc_to_ptr(out_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &in_re, + sycl::buffer, 1> &in_im, + sycl::buffer, 1> &out_re, + sycl::buffer, 1> &out_im) { + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + auto inre_acc = in_re.template get_access(cgh); + auto inim_acc = in_im.template get_access(cgh); + auto outre_acc = out_re.template get_access(cgh); + auto outim_acc = out_im.template get_access(cgh); + + detail::host_task(cgh, [=]() { + auto inre_ptr = const_cast *>(detail::acc_to_ptr(inre_acc)); + auto inim_ptr = const_cast *>(detail::acc_to_ptr(inim_acc)); + DFT_ERROR status = + DftiComputeForward(desc_acc[detail::DIR::fwd], inre_ptr, inim_ptr, + detail::acc_to_ptr(outre_acc), detail::acc_to_ptr(outim_acc)); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //USM version //In-place transform -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout, +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *inout, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; + detail::expect_config( + desc, "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], inout); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, - data_type *inout_im, +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar *inout_re, + scalar *inout_im, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], inout_re, inout_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *in, + bwd *out, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; + // Check: inplace + detail::expect_config(desc, + "Unexpected value for placement"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = DftiComputeForward(desc_acc[detail::DIR::fwd], in, out); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re, - input_type *in_im, output_type *out_re, - output_type *out_im, +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar *in_re, + scalar *in_im, + scalar *out_re, + scalar *out_im, const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_forward", "Not implemented for MKLCPU"); - return sycl::event{}; + detail::expect_config( + desc, "Unexpected value for complex storage"); + + auto commit_handle = dft::detail::get_commit(desc); + detail::check_fwd_commit(desc); + sycl::queue &cpu_queue{ commit_handle->get_queue() }; + + auto mklcpu_desc_buffer{ detail::get_buffer(commit_handle) }; + + return cpu_queue.submit([&](sycl::handler &cgh) { + auto desc_acc = mklcpu_desc_buffer.template get_access(cgh); + + cgh.depends_on(dependencies); + detail::host_task(cgh, [=]() { + DFT_ERROR status = + DftiComputeForward(desc_acc[detail::DIR::fwd], in_re, in_im, out_re, out_im); + if (status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/forward/mklcpu", "compute_forward", + std::string("DftiComputeForward failed : ") + DftiErrorMessage(status)); + } + }); + }); } // Template function instantiations #include "dft/backends/backend_forward_instantiations.cxx" -} // namespace mklcpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklcpu diff --git a/src/dft/backends/mklcpu/mklcpu_helpers.hpp b/src/dft/backends/mklcpu/mklcpu_helpers.hpp new file mode 100644 index 000000000..55a8345c2 --- /dev/null +++ b/src/dft/backends/mklcpu/mklcpu_helpers.hpp @@ -0,0 +1,178 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_ +#define _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_ + +#include "oneapi/mkl/exceptions.hpp" +#include "oneapi/mkl/dft/detail/types_impl.hpp" + +// MKLCPU header +#include "mkl_dfti.h" + +namespace oneapi::mkl::dft::mklcpu::detail { + +template +static inline auto host_task_internal(H& cgh, F f, int) -> decltype(cgh.host_task(f)) { + return cgh.host_task(f); +} + +template +static inline void host_task(H& cgh, F f) { + (void)host_task_internal(cgh, f, 0); +} + +template +class kernel_name {}; + +/// Convert domain to equivalent backend native value. +inline constexpr DFTI_CONFIG_VALUE to_mklcpu(dft::detail::domain dom) { + if (dom == dft::detail::domain::REAL) { + return DFTI_REAL; + } + else { + return DFTI_COMPLEX; + } +} + +/// Convert precision to equivalent backend native value. +inline constexpr DFTI_CONFIG_VALUE to_mklcpu(dft::detail::precision dom) { + if (dom == dft::detail::precision::SINGLE) { + return DFTI_SINGLE; + } + else { + return DFTI_DOUBLE; + } +} + +/// Convert a config_param to equivalent backend native value. +inline constexpr DFTI_CONFIG_PARAM to_mklcpu(dft::detail::config_param param) { + using iparam = dft::detail::config_param; + switch (param) { + case iparam::FORWARD_DOMAIN: return DFTI_FORWARD_DOMAIN; + case iparam::DIMENSION: return DFTI_DIMENSION; + case iparam::LENGTHS: return DFTI_LENGTHS; + case iparam::PRECISION: return DFTI_PRECISION; + case iparam::FORWARD_SCALE: return DFTI_FORWARD_SCALE; + case iparam::NUMBER_OF_TRANSFORMS: return DFTI_NUMBER_OF_TRANSFORMS; + case iparam::COMPLEX_STORAGE: return DFTI_COMPLEX_STORAGE; + case iparam::REAL_STORAGE: return DFTI_REAL_STORAGE; + case iparam::CONJUGATE_EVEN_STORAGE: return DFTI_CONJUGATE_EVEN_STORAGE; + case iparam::FWD_DISTANCE: return DFTI_FWD_DISTANCE; + case iparam::BWD_DISTANCE: return DFTI_BWD_DISTANCE; + case iparam::WORKSPACE: return DFTI_WORKSPACE; + case iparam::ORDERING: return DFTI_ORDERING; + case iparam::TRANSPOSE: return DFTI_TRANSPOSE; + case iparam::PACKED_FORMAT: return DFTI_PACKED_FORMAT; + case iparam::COMMIT_STATUS: return DFTI_COMMIT_STATUS; + default: + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config param."); + return static_cast(0); + } +} + +/** Convert a config_value to the backend's native value. Throw on invalid input. + * @tparam Param The config param the value is for. + * @param value The config value to convert. +**/ +template +inline constexpr int to_mklcpu(dft::detail::config_value value); + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::COMPLEX_COMPLEX) { + return DFTI_COMPLEX_COMPLEX; + } + else if (value == dft::detail::config_value::REAL_REAL) { + return DFTI_REAL_REAL; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for complex storage."); + return 0; + } +} + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::REAL_REAL) { + return DFTI_REAL_REAL; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for real storage."); + return 0; + } +} +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::COMPLEX_COMPLEX) { + return DFTI_COMPLEX_COMPLEX; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for conjugate even storage."); + return 0; + } +} + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::INPLACE) { + return DFTI_INPLACE; + } + else if (value == dft::detail::config_value::NOT_INPLACE) { + return DFTI_NOT_INPLACE; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for inplace."); + return 0; + } +} + +template <> +inline constexpr int to_mklcpu( + dft::detail::config_value value) { + if (value == dft::detail::config_value::CCE_FORMAT) { + return DFTI_CCE_FORMAT; + } + else { + throw mkl::invalid_argument("dft", "MKLCPU descriptor set_value()", + "Invalid config value for packed format."); + return 0; + } +} + +using mklcpu_desc_t = DFTI_DESCRIPTOR_HANDLE; + +template +typename AccType::value_type* acc_to_ptr(AccType acc) { + // no need to decorate the pointer with the address space for mklcpu since its just getting passed to the a host function. + return acc.template get_multi_ptr().get(); +} + +} // namespace oneapi::mkl::dft::mklcpu::detail + +#endif // _ONEMKL_DFT_SRC_MKLCPU_HELPERS_HPP_ diff --git a/src/dft/backends/mklgpu/CMakeLists.txt b/src/dft/backends/mklgpu/CMakeLists.txt index c60baf843..7e88a23d9 100644 --- a/src/dft/backends/mklgpu/CMakeLists.txt +++ b/src/dft/backends/mklgpu/CMakeLists.txt @@ -20,7 +20,7 @@ set(LIB_NAME onemkl_dft_mklgpu) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(MKL REQUIRED) +include(WarningsUtils) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT @@ -28,20 +28,37 @@ add_library(${LIB_OBJ} OBJECT commit.cpp forward.cpp backward.cpp - compute_signature.cpp $<$: mkl_dft_gpu_wrappers.cpp> ) +add_dependencies(onemkl_backend_libs_dft ${LIB_NAME}) target_include_directories(${LIB_OBJ} - PRIVATE ${PROJECT_SOURCE_DIR}/include - ${PROJECT_SOURCE_DIR}/src + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_NAME} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) +if(TARGET MKL::MKL_SYCL::DFT) + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_SYCL::DFT + PRIVATE onemkl_warnings + ) +else() + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_DPCPP + PRIVATE onemkl_warnings + ) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/dft/backends/mklgpu/backward.cpp b/src/dft/backends/mklgpu/backward.cpp index adb819f2a..6c4896c66 100644 --- a/src/dft/backends/mklgpu/backward.cpp +++ b/src/dft/backends/mklgpu/backward.cpp @@ -23,21 +23,17 @@ #include #endif -#include "oneapi/mkl/types.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" -#include "oneapi/mkl/dft/detail/types_impl.hpp" #include "oneapi/mkl/dft/detail/descriptor_impl.hpp" -#include "dft/backends/mklgpu/mklgpu_helpers.hpp" + +#include "mklgpu_helpers.hpp" // MKLGPU header #include "oneapi/mkl/dfti.hpp" -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklgpu { +namespace oneapi::mkl::dft::mklgpu { namespace detail { /// Forward a MKLGPU DFT call to the backend, checking that the commit impl is valid. @@ -45,12 +41,15 @@ namespace detail { template inline auto compute_backward(dft::detail::descriptor &desc, ArgTs &&... args) { using mklgpu_desc_t = dft::descriptor; + using desc_shptr_t = std::shared_ptr; + using handle_t = std::pair; auto commit_handle = dft::detail::get_commit(desc); if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklgpu) { throw mkl::invalid_argument("DFT", "compute_backward", "DFT descriptor has not been commited for MKLGPU"); } - auto mklgpu_desc = reinterpret_cast(commit_handle->get_handle()); + auto handle = reinterpret_cast(commit_handle->get_handle()); + auto mklgpu_desc = handle->second; // Second because backward DFT. int commit_status{ DFTI_UNCOMMITTED }; mklgpu_desc->get_value(dft::config_param::COMMIT_STATUS, &commit_status); if (commit_status != DFTI_COMMITTED) { @@ -78,25 +77,28 @@ inline auto expect_config(DescT &desc, const char *message) { // BUFFER version //In-place transform -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout) { +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout) { detail::expect_config( desc, "Unexpected value for placement"); return detail::compute_backward(desc, inout); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +template +ONEMKL_EXPORT void compute_backward(descriptor_type & /*desc*/, + sycl::buffer, 1> & /*inout_re*/, + sycl::buffer, 1> & /*inout_im*/) { throw mkl::unimplemented("DFT", "compute_backward", "MKLGPU does not support compute_backward(desc, inout_re, inout_im)."); } //Out-of-place transform -template -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out) { +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { detail::expect_config(desc, "Unexpected value for placement"); @@ -104,11 +106,12 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer -ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> & /*in_re*/, + sycl::buffer, 1> & /*in_im*/, + sycl::buffer, 1> & /*out_re*/, + sycl::buffer, 1> & /*out_im*/) { detail::expect_config( desc, "Unexpected value for complex storage"); @@ -120,8 +123,8 @@ ONEMKL_EXPORT void compute_backward(descriptor_type &desc, sycl::buffer -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout, +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd *inout, const std::vector &dependencies) { detail::expect_config( desc, "Unexpected value for placement"); @@ -129,17 +132,20 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *ino } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, data_type *inout_re, - data_type *inout_im, - const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_backward", - "MKLGPU does not support compute_backward(desc, inout_re, inout_im)."); +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type & /*desc*/, + scalar * /*inout_re*/, + scalar * /*inout_im*/, + const std::vector & /*dependencies*/) { + throw mkl::unimplemented( + "DFT", "compute_backward", + "MKLGPU does not support compute_backward(desc, inout_re, inout_im, dependencies)."); } //Out-of-place transform -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in, output_type *out, +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd *in, + fwd *out, const std::vector &dependencies) { detail::expect_config(desc, @@ -148,11 +154,13 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in_re, - input_type *in_im, output_type *out_re, - output_type *out_im, - const std::vector &dependencies) { +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, + scalar * /*in_re*/, + scalar * /*in_im*/, + scalar * /*out_re*/, + scalar * /*out_im*/, + const std::vector & /*dependencies*/) { detail::expect_config( desc, "Unexpected value for complex storage"); @@ -164,7 +172,4 @@ ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, input_type *in // Template function instantiations #include "dft/backends/backend_backward_instantiations.cxx" -} // namespace mklgpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklgpu diff --git a/src/dft/backends/mklgpu/commit.cpp b/src/dft/backends/mklgpu/commit.cpp index 897e0ae5b..d3a3f1cd6 100644 --- a/src/dft/backends/mklgpu/commit.cpp +++ b/src/dft/backends/mklgpu/commit.cpp @@ -33,35 +33,49 @@ #include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" #include "dft/backends/mklgpu/mklgpu_helpers.hpp" +#include "../stride_helper.hpp" // MKLGPU header #include "oneapi/mkl/dfti.hpp" +// MKL 2024.1 deprecates input/output strides. +#include "mkl_version.h" +#if INTEL_MKL_VERSION < 20240001 +#error MKLGPU requires oneMKL 2024.1 or later +#endif + /** Note that in this file, the Intel oneMKL closed-source library's interface mirrors the interface of this OneMKL open-source library. Consequently, the types under dft::TYPE are closed-source oneMKL types, and types under dft::detail::TYPE are from this library. **/ -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklgpu { +namespace oneapi::mkl::dft::mklgpu { namespace detail { /// Commit impl class specialization for MKLGPU. template -class commit_derived_impl final : public dft::detail::commit_impl { +class mklgpu_commit final : public dft::detail::commit_impl { private: // Equivalent MKLGPU precision and domain from OneMKL's precision / domain. static constexpr dft::precision mklgpu_prec = to_mklgpu(prec); static constexpr dft::domain mklgpu_dom = to_mklgpu(dom); + + // A pair of descriptors are needed because of the [[deprecated]]IN/OUTPUT_STRIDES vs F/BWD_STRIDES API. + // Of the pair [0] is fwd DFT, [1] is backward DFT. If possible, the pointers refer to the same desciptor. + // Both pointers must be valid. using mklgpu_descriptor_t = dft::descriptor; + using descriptor_shptr_t = std::shared_ptr; + using handle_t = std::pair; + + using scalar_type = typename dft::detail::commit_impl::scalar_type; public: - commit_derived_impl(sycl::queue queue, const dft::detail::dft_values& config_values) - : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu), - handle(config_values.dimensions) { + mklgpu_commit(sycl::queue queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu, + config_values), + handle(std::make_shared(config_values.dimensions), nullptr) { + handle.second = handle.first; // Make sure the bwd pointer is valid. // MKLGPU does not throw an informative exception for the following: if constexpr (prec == dft::detail::precision::DOUBLE) { if (!queue.get_device().has(sycl::aspect::fp64)) { @@ -72,13 +86,51 @@ class commit_derived_impl final : public dft::detail::commit_impl { } virtual void commit(const dft::detail::dft_values& config_values) override { - set_value(handle, config_values); + this->external_workspace_helper_ = + oneapi::mkl::dft::detail::external_workspace_helper( + config_values.workspace_placement == + oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); + + auto stride_choice = dft::detail::get_stride_api(config_values); + throw_on_invalid_stride_api("MKLGPU commit", stride_choice); + // A separate descriptor for each direction may not be required. + bool one_descriptor = (stride_choice == dft::detail::stride_api::FB_STRIDES) || + (config_values.input_strides == config_values.output_strides); + bool forward_good = true; + // Make sure that second is always pointing to something new if this is a recommit. + handle.second = handle.first; + + // Generate forward DFT descriptor. If using FWD/BWD_STRIDES API, only + // one descriptor is needed. + set_value(*handle.first, config_values, true, stride_choice); try { - handle.commit(this->get_queue()); + handle.first->commit(this->get_queue()); } catch (const std::exception& mkl_exception) { - // Catching the real MKL exception causes headaches with naming. - throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); + // Catching the real Intel oneMKL exception causes headaches with naming + forward_good = false; + if (one_descriptor) { + throw mkl::exception("dft/backends/mklgpu" + "commit", + mkl_exception.what()); + } + } + + // Generate backward DFT descriptor only if required. + if (!one_descriptor) { + handle.second = std::make_shared(config_values.dimensions); + set_value(*handle.second, config_values, false, stride_choice); + try { + handle.second->commit(this->get_queue()); + } + catch (const std::exception& mkl_exception) { + // Catching the real Intel oneMKL exception causes headaches with naming. + if (!forward_good) { + throw mkl::exception("dft/backends/mklgpu" + "commit", + mkl_exception.what()); + } + } } } @@ -86,13 +138,34 @@ class commit_derived_impl final : public dft::detail::commit_impl { return &handle; } - ~commit_derived_impl() override = default; + ~mklgpu_commit() override = default; + + virtual void set_workspace(scalar_type* usm_workspace) override { + this->external_workspace_helper_.set_workspace_throw(*this, usm_workspace); + handle.first->set_workspace(usm_workspace); + if (handle.first != handle.second) { + handle.second->set_workspace(usm_workspace); + } + } + + virtual void set_workspace(sycl::buffer& buffer_workspace) override { + this->external_workspace_helper_.set_workspace_throw(*this, buffer_workspace); + handle.first->set_workspace(buffer_workspace); + if (handle.first != handle.second) { + handle.second->set_workspace(buffer_workspace); + } + } + +#define BACKEND mklgpu +#include "../backend_compute_signature.cxx" +#undef BACKEND private: // The native MKLGPU class. - mklgpu_descriptor_t handle; + handle_t handle; - void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values& config) { + void set_value(mklgpu_descriptor_t& desc, const dft::detail::dft_values& config, + bool assume_fwd_dft, dft::detail::stride_api stride_choice) { using onemkl_param = dft::detail::config_param; using backend_param = dft::config_param; @@ -112,11 +185,37 @@ class commit_derived_impl final : public dft::detail::commit_impl { to_mklgpu(config.conj_even_storage)); desc.set_value(backend_param::PLACEMENT, to_mklgpu(config.placement)); - desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data()); + + if (stride_choice == dft::detail::stride_api::FB_STRIDES) { + if (config.fwd_strides[0] != 0 || config.fwd_strides[0] != 0) { + throw mkl::unimplemented("dft/backends/mklgpu", "commit", + "MKLGPU does not support nonzero offsets."); + } + desc.set_value(backend_param::FWD_STRIDES, config.fwd_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.bwd_strides.data()); + } + else { + if (config.input_strides[0] != 0 || config.output_strides[0] != 0) { + throw mkl::unimplemented("dft/backends/mklgpu", "commit", + "MKLGPU does not support nonzero offsets."); + } + if (assume_fwd_dft) { + desc.set_value(backend_param::FWD_STRIDES, config.input_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.output_strides.data()); + } + else { + desc.set_value(backend_param::FWD_STRIDES, config.output_strides.data()); + desc.set_value(backend_param::BWD_STRIDES, config.input_strides.data()); + } + } desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist); desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist); - // Setting the workspace causes an FFT_INVALID_DESCRIPTOR. + if (config.workspace_placement == dft::detail::config_value::WORKSPACE_EXTERNAL) { + // Setting WORKSPACE_INTERNAL (default) causes FFT_INVALID_DESCRIPTOR. + desc.set_value(backend_param::WORKSPACE, + to_mklgpu_config_value( + config.workspace_placement)); + } // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: if (config.ordering != dft::detail::config_value::ORDERED) { throw mkl::invalid_argument("dft/backends/mklgpu", "commit", @@ -128,13 +227,21 @@ class commit_derived_impl final : public dft::detail::commit_impl { "MKLGPU only supports non-transposed."); } } + + // This is called by the workspace_helper, and is not part of the user API. + virtual std::int64_t get_workspace_external_bytes_impl() override { + std::size_t workspaceSizeFwd = 0, workspaceSizeBwd = 0; + handle.first->get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSizeFwd); + handle.second->get_value(dft::config_param::WORKSPACE_BYTES, &workspaceSizeBwd); + return static_cast(std::max(workspaceSizeFwd, workspaceSizeFwd)); + } }; } // namespace detail template dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, sycl::queue& sycl_queue) { - return new detail::commit_derived_impl(sycl_queue, desc.get_values()); + return new detail::mklgpu_commit(sycl_queue, desc.get_values()); } template dft::detail::commit_impl* @@ -154,7 +261,4 @@ create_commit( const dft::detail::descriptor&, sycl::queue&); -} // namespace mklgpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklgpu diff --git a/src/dft/backends/mklgpu/forward.cpp b/src/dft/backends/mklgpu/forward.cpp index 5185cdbbf..39da42e45 100644 --- a/src/dft/backends/mklgpu/forward.cpp +++ b/src/dft/backends/mklgpu/forward.cpp @@ -24,40 +24,39 @@ #include #endif -#include "oneapi/mkl/types.hpp" #include "oneapi/mkl/exceptions.hpp" #include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" -#include "oneapi/mkl/dft/detail/types_impl.hpp" #include "oneapi/mkl/dft/detail/descriptor_impl.hpp" -#include "dft/backends/mklgpu/mklgpu_helpers.hpp" +#include "mklgpu_helpers.hpp" // MKLGPU header #include "oneapi/mkl/dfti.hpp" /** -Note that in this file, the Intel MKL-GPU library's interface mirrors the interface -of this OneMKL library. Consequently, the types under dft::TYPE are closed-source MKL types, -and types under dft::detail::TYPE are from this library. +Note that in this file, the Intel oneMKL-GPU library's interface mirrors the +interface of this OneMKL library. Consequently, the types under dft::TYPE are +closed-source Intel oneMKL types, and types under dft::detail::TYPE are from +this library. **/ -namespace oneapi { -namespace mkl { -namespace dft { -namespace mklgpu { +namespace oneapi::mkl::dft::mklgpu { namespace detail { /// Forward a MKLGPU DFT call to the backend, checking that the commit impl is valid. /// Assumes backend descriptor values match those of the frontend. template inline auto compute_forward(dft::detail::descriptor &desc, ArgTs &&... args) { using mklgpu_desc_t = dft::descriptor; + using desc_shptr_t = std::shared_ptr; + using handle_t = std::pair; auto commit_handle = dft::detail::get_commit(desc); if (commit_handle == nullptr || commit_handle->get_backend() != backend::mklgpu) { throw mkl::invalid_argument("DFT", "compute_forward", "DFT descriptor has not been commited for MKLGPU"); } - auto mklgpu_desc = reinterpret_cast(commit_handle->get_handle()); + auto handle = reinterpret_cast(commit_handle->get_handle()); + auto mklgpu_desc = handle->first; // First because forward DFT. int commit_status{ DFTI_UNCOMMITTED }; mklgpu_desc->get_value(dft::config_param::COMMIT_STATUS, &commit_status); if (commit_status != DFTI_COMMITTED) { @@ -85,25 +84,27 @@ inline auto expect_config(DescT &desc, const char *message) { // BUFFER version //In-place transform -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout) { +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout) { detail::expect_config( desc, "Unexpected value for placement"); return detail::compute_forward(desc, inout); } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &inout_re, - sycl::buffer &inout_im) { +template +ONEMKL_EXPORT void compute_forward(descriptor_type & /*desc*/, + sycl::buffer, 1> & /*inout_re*/, + sycl::buffer, 1> & /*inout_im*/) { throw mkl::unimplemented("DFT", "compute_forward", "MKLGPU does not support compute_forward(desc, inout_re, inout_im)."); } //Out-of-place transform -template -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in, - sycl::buffer &out) { +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { detail::expect_config(desc, "Unexpected value for placement"); @@ -111,11 +112,12 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer -ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer &in_re, - sycl::buffer &in_im, - sycl::buffer &out_re, - sycl::buffer &out_im) { +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> & /*in_re*/, + sycl::buffer, 1> & /*in_im*/, + sycl::buffer, 1> & /*out_re*/, + sycl::buffer, 1> & /*out_im*/) { detail::expect_config( desc, "Unexpected value for complex storage"); @@ -127,8 +129,8 @@ ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout, +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *inout, const std::vector &dependencies) { detail::expect_config( desc, "Unexpected value for placement"); @@ -136,17 +138,20 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inou } //In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, data_type *inout_re, - data_type *inout_im, - const std::vector &dependencies) { - throw mkl::unimplemented("DFT", "compute_forward", - "MKLGPU does not support compute_forward(desc, inout_re, inout_im)."); +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type & /*desc*/, + scalar * /*inout_re*/, + scalar * /*inout_im*/, + const std::vector & /*dependencies*/) { + throw mkl::unimplemented( + "DFT", "compute_forward", + "MKLGPU does not support compute_forward(desc, inout_re, inout_im, dependencies)."); } //Out-of-place transform -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, output_type *out, +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *in, + bwd *out, const std::vector &dependencies) { detail::expect_config(desc, @@ -155,23 +160,22 @@ ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in, } //Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format -template -ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, input_type *in_re, - input_type *in_im, output_type *out_re, - output_type *out_im, - const std::vector &dependencies) { +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, + scalar * /*in_re*/, + scalar * /*in_im*/, + scalar * /*out_re*/, + scalar * /*out_im*/, + const std::vector & /*dependencies*/) { detail::expect_config( desc, "Unexpected value for complex storage"); throw oneapi::mkl::unimplemented( - "DFT", "compute_forward(desc, in_re, in_im, out_re, out_im, deps)", + "DFT", "compute_forward(desc, in_re, in_im, out_re, out_im, dependencies)", "MKLGPU does not support out-of-place FFT with real-real complex storage."); } // Template function instantiations #include "dft/backends/backend_forward_instantiations.cxx" -} // namespace mklgpu -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::mklgpu diff --git a/src/dft/backends/mklgpu/mklgpu_helpers.hpp b/src/dft/backends/mklgpu/mklgpu_helpers.hpp index ed9b8af2f..6813297ea 100644 --- a/src/dft/backends/mklgpu/mklgpu_helpers.hpp +++ b/src/dft/backends/mklgpu/mklgpu_helpers.hpp @@ -66,14 +66,14 @@ inline constexpr dft::config_param to_mklgpu(dft::detail::config_param param) { case iparam::COMPLEX_STORAGE: return oparam::COMPLEX_STORAGE; case iparam::REAL_STORAGE: return oparam::REAL_STORAGE; case iparam::CONJUGATE_EVEN_STORAGE: return oparam::CONJUGATE_EVEN_STORAGE; - case iparam::INPUT_STRIDES: return oparam::INPUT_STRIDES; - case iparam::OUTPUT_STRIDES: return oparam::OUTPUT_STRIDES; case iparam::FWD_DISTANCE: return oparam::FWD_DISTANCE; case iparam::BWD_DISTANCE: return oparam::BWD_DISTANCE; case iparam::WORKSPACE: return oparam::WORKSPACE; case iparam::ORDERING: return oparam::ORDERING; case iparam::TRANSPOSE: return oparam::TRANSPOSE; case iparam::PACKED_FORMAT: return oparam::PACKED_FORMAT; + case iparam::WORKSPACE_PLACEMENT: return oparam::WORKSPACE; // Same as WORKSPACE + case iparam::WORKSPACE_EXTERNAL_BYTES: return oparam::WORKSPACE_BYTES; case iparam::COMMIT_STATUS: return oparam::COMMIT_STATUS; default: throw mkl::invalid_argument("dft", "MKLGPU descriptor set_value()", @@ -143,6 +143,31 @@ inline constexpr int to_mklgpu( return 0; } } + +/** Convert a config_value to the backend's native value. Throw on invalid input. + * @tparam Param The config param the value is for. + * @param value The config value to convert. +**/ +template +inline constexpr dft::config_value to_mklgpu_config_value(dft::detail::config_value value); + +template <> +inline constexpr dft::config_value +to_mklgpu_config_value( + dft::detail::config_value value) { + if (value == dft::detail::config_value::WORKSPACE_AUTOMATIC) { + // NB: dft::config_value != dft::detail::config_value + return dft::config_value::WORKSPACE_INTERNAL; + } + else if (value == dft::detail::config_value::WORKSPACE_EXTERNAL) { + return dft::config_value::WORKSPACE_EXTERNAL; + } + else { + throw mkl::invalid_argument("dft", "MKLGPU descriptor set_value()", + "Invalid config value for workspace placement."); + return dft::config_value::WORKSPACE_INTERNAL; + } +} } // namespace detail } // namespace mklgpu } // namespace dft diff --git a/src/dft/backends/portfft/CMakeLists.txt b/src/dft/backends/portfft/CMakeLists.txt new file mode 100644 index 000000000..50e4d30d1 --- /dev/null +++ b/src/dft/backends/portfft/CMakeLists.txt @@ -0,0 +1,134 @@ +#=============================================================================== +# Copyright Codeplay Software Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + + +check_cxx_compiler_flag("-fsycl" IS_DPCPP) + +set(FOUND_TARGETS 0) + +if (NOT (CMAKE_CXX_FLAGS STREQUAL "")) + string(FIND ${CMAKE_CXX_FLAGS} "fsycl-targets" TARGETS_IDX) + if (TARGETS_IDX GREATER_EQUAL 0) + set(FOUND_TARGETS 1) + message(STATUS "fsycl-targets flag found, not setting targets") + endif() +endif() + +if (IS_DPCPP AND UNIX AND NOT FOUND_TARGETS) + message(WARNING "fsycl-targets flag not found, enabling all backends") + set(TARGETS_COMPILE_OPTIONS -fsycl-unnamed-lambda) + set(TARGETS_LINK_OPTIONS -fsycl-unnamed-lambda) + + # spir64 must be last in the list due to a bug in dpcpp 2024.0.0 + set(TARGETS_TRIPLES "spir64") + if(dpcpp_supports_nvptx64) + set(TARGETS_TRIPLES nvptx64-nvidia-cuda,${TARGETS_TRIPLES}) + endif() + + if (NOT (HIP_TARGETS STREQUAL "")) + set(TARGETS_TRIPLES amdgcn-amd-amdhsa,${TARGETS_TRIPLES}) + list(APPEND TARGETS_COMPILE_OPTIONS -Xsycl-target-backend=amdgcn-amd-amdhsa --offload-arch=${HIP_TARGETS}) + list(APPEND TARGETS_LINK_OPTIONS -Xsycl-target-backend=amdgcn-amd-amdhsa --offload-arch=${HIP_TARGETS}) + else() + message(WARNING "Can't enable hip backend, HIP_TARGETS has not been set.") + endif() + + message(STATUS "portFFT target triple set to ${TARGETS_TRIPLES}") + + list(APPEND TARGETS_COMPILE_OPTIONS -fsycl-targets=${TARGETS_TRIPLES}) + list(APPEND TARGETS_LINK_OPTIONS -fsycl-targets=${TARGETS_TRIPLES}) + + target_compile_options(ONEMKL::SYCL::SYCL INTERFACE ${TARGETS_COMPILE_OPTIONS}) + target_link_options(ONEMKL::SYCL::SYCL INTERFACE ${TARGETS_LINK_OPTIONS}) +endif() + +set(LIB_NAME onemkl_dft_portfft) +set(LIB_OBJ ${LIB_NAME}_obj) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + descriptor.cpp + commit.cpp + $<$: mkl_dft_portfft_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_dft ${LIB_NAME}) + +find_package(portfft QUIET) +if (NOT portfft_FOUND) + message(STATUS "portFFT - not found locally, downloading") + + include(FetchContent) + set(FETCHCONTENT_BASE_DIR "${CMAKE_BINARY_DIR}/deps") + FetchContent_Declare( + portfft + GIT_REPOSITORY https://github.com/codeplaysoftware/portFFT.git + GIT_TAG e4251e8ef89a8ac4d851a4cc08a0577a28f953e0 + ) + FetchContent_MakeAvailable(portfft) + message(STATUS "portFFT - downloaded") + target_link_libraries(${LIB_OBJ} PRIVATE portfft) +else() + message(STATUS "portFFT - found") + target_link_libraries(${LIB_OBJ} PRIVATE portfft::portfft) +endif() + +target_link_libraries(${LIB_OBJ} PRIVATE onemkl_warnings) + +target_include_directories(${LIB_OBJ} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_NAME} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/portfft/commit.cpp b/src/dft/backends/portfft/commit.cpp new file mode 100644 index 000000000..3c0a6f186 --- /dev/null +++ b/src/dft/backends/portfft/commit.cpp @@ -0,0 +1,346 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include +#include + +#include + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "../stride_helper.hpp" + +#include "portfft_helper.hpp" + +// alias to avoid ambiguity +namespace pfft = portfft; + +namespace oneapi::mkl::dft::portfft { +namespace detail { + +template +class portfft_commit final : public dft::detail::commit_impl { +private: + using scalar_type = typename dft::detail::commit_impl::scalar_type; + using fwd_type = typename dft::detail::commit_impl::fwd_type; + using bwd_type = typename dft::detail::commit_impl::bwd_type; + using descriptor_type = typename dft::detail::descriptor; + + static constexpr pfft::domain domain = + dom == dft::domain::REAL ? pfft::domain::REAL : pfft::domain::COMPLEX; + // since only complex-to-complex transforms are supported, we expect both directions to be valid or neither. + std::array, 2> committed_descriptors = { std::nullopt, + std::nullopt }; + +public: + portfft_commit(sycl::queue& queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::portfft, + config_values) { + if constexpr (prec == dft::detail::precision::DOUBLE) { + if (!queue.get_device().has(sycl::aspect::fp64)) { + throw mkl::exception("DFT", "commit", "Device does not support double precision."); + } + } + } + + void commit(const dft::detail::dft_values& config_values) override { + // not available in portFFT: + this->external_workspace_helper_ = + oneapi::mkl::dft::detail::external_workspace_helper( + config_values.workspace_placement == + oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); + if (config_values.workspace != config_value::ALLOW) { + throw mkl::unimplemented("dft/backends/portfft", __FUNCTION__, + "portFFT only supports ALLOW for the WORKSPACE parameter"); + } + if (config_values.ordering != config_value::ORDERED) { + throw mkl::unimplemented("dft/backends/portfft", __FUNCTION__, + "portFFT only supports ORDERED for the ORDERING parameter"); + } + if (config_values.transpose) { + throw mkl::unimplemented("dft/backends/portfft", __FUNCTION__, + "portFFT does not supported transposed output"); + } + + auto stride_api_choice = dft::detail::get_stride_api(config_values); + dft::detail::throw_on_invalid_stride_api("portFFT commit", stride_api_choice); + dft::detail::stride_vectors stride_vecs(config_values, stride_api_choice); + + // forward descriptor + pfft::descriptor fwd_desc( + { config_values.dimensions.cbegin(), config_values.dimensions.cend() }); + fwd_desc.forward_scale = config_values.fwd_scale; + fwd_desc.backward_scale = config_values.bwd_scale; + fwd_desc.number_of_transforms = + static_cast(config_values.number_of_transforms); + fwd_desc.complex_storage = config_values.complex_storage == config_value::COMPLEX_COMPLEX + ? pfft::complex_storage::INTERLEAVED_COMPLEX + : pfft::complex_storage::SPLIT_COMPLEX; + fwd_desc.placement = config_values.placement == config_value::INPLACE + ? pfft::placement::IN_PLACE + : pfft::placement::OUT_OF_PLACE; + fwd_desc.forward_offset = static_cast(stride_vecs.offset_fwd_in); + fwd_desc.backward_offset = static_cast(stride_vecs.offset_fwd_out); + fwd_desc.forward_strides = { stride_vecs.fwd_in.cbegin() + 1, stride_vecs.fwd_in.cend() }; + fwd_desc.backward_strides = { stride_vecs.fwd_out.cbegin() + 1, + stride_vecs.fwd_out.cend() }; + fwd_desc.forward_distance = static_cast(config_values.fwd_dist); + fwd_desc.backward_distance = static_cast(config_values.bwd_dist); + + // backward descriptor + pfft::descriptor bwd_desc( + { config_values.dimensions.cbegin(), config_values.dimensions.cend() }); + bwd_desc.forward_scale = config_values.fwd_scale; + bwd_desc.backward_scale = config_values.bwd_scale; + bwd_desc.number_of_transforms = + static_cast(config_values.number_of_transforms); + bwd_desc.complex_storage = config_values.complex_storage == config_value::COMPLEX_COMPLEX + ? pfft::complex_storage::INTERLEAVED_COMPLEX + : pfft::complex_storage::SPLIT_COMPLEX; + bwd_desc.placement = config_values.placement == config_value::INPLACE + ? pfft::placement::IN_PLACE + : pfft::placement::OUT_OF_PLACE; + bwd_desc.forward_offset = static_cast(stride_vecs.offset_bwd_out); + bwd_desc.backward_offset = static_cast(stride_vecs.offset_bwd_in); + bwd_desc.forward_strides = { stride_vecs.bwd_out.cbegin() + 1, stride_vecs.bwd_out.cend() }; + bwd_desc.backward_strides = { stride_vecs.bwd_in.cbegin() + 1, + stride_vecs.bwd_in.cend() }; + bwd_desc.forward_distance = static_cast(config_values.fwd_dist); + bwd_desc.backward_distance = static_cast(config_values.bwd_dist); + + try { + auto q = this->get_queue(); + committed_descriptors[0] = fwd_desc.commit(q); + committed_descriptors[1] = bwd_desc.commit(q); + } + catch (const pfft::unsupported_configuration& e) { + throw oneapi::mkl::unimplemented("dft/backends/portfft", __FUNCTION__, e.what()); + } + } + + ~portfft_commit() override = default; + + void* get_handle() noexcept override { + return committed_descriptors.data(); + } + + // All the compute functions are implementated here so they are in the same translation unit as the commit function. + // If the use of the kernel bundle is in a seperate translation unit from the one it was translated in, the runtime can fail to find it. + + // forward inplace COMPLEX_COMPLEX + void forward_ip_cc(descriptor_type& desc, sycl::buffer& inout) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + detail::get_descriptors(desc)[0]->compute_forward(inout); + } + } + sycl::event forward_ip_cc(descriptor_type& desc, fwd_type* inout, + const std::vector& dependencies) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + return detail::get_descriptors(desc)[0]->compute_forward(inout, dependencies); + } + else { + return {}; + } + } + + // forward inplace REAL_REAL + void forward_ip_rr(descriptor_type& desc, sycl::buffer&, + sycl::buffer&) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + throw oneapi::mkl::unimplemented("DFT", "compute_forward(desc, inout_re, inout_im)", + "portFFT does not support real-real complex storage."); + } + sycl::event forward_ip_rr(descriptor_type& desc, scalar_type*, scalar_type*, + const std::vector&) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + throw oneapi::mkl::unimplemented("DFT", + "compute_forward(desc, inout_re, inout_im, dependencies)", + "portFFT does not support real-real complex storage."); + } + + // forward out-of-place COMPLEX_COMPLEX + void forward_op_cc(descriptor_type& desc, sycl::buffer& in, + sycl::buffer& out) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + detail::get_descriptors(desc)[0]->compute_forward(in, out); + } + } + sycl::event forward_op_cc(descriptor_type& desc, fwd_type* in, bwd_type* out, + const std::vector& dependencies) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + return detail::get_descriptors(desc)[0]->compute_forward(in, out, dependencies); + } + else { + return {}; + } + } + + // forward out-of-place REAL_REAL + void forward_op_rr(descriptor_type& desc, sycl::buffer&, + sycl::buffer&, sycl::buffer&, + sycl::buffer&) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_forward"); + throw oneapi::mkl::unimplemented("DFT", + "compute_forward(desc, in_re, in_im, out_re, out_im)", + "portFFT does not support real-real complex storage."); + } + sycl::event forward_op_rr(descriptor_type& desc, scalar_type*, scalar_type*, scalar_type*, + scalar_type*, const std::vector&) override { + dft::detail::get_commit(desc)->template compute_call_throw("compute_forward"); + throw oneapi::mkl::unimplemented( + "DFT", "compute_forward(desc, in_re, in_im, out_re, out_im, dependencies)", + "portFFT does not support real-real complex storage."); + } + + // backward inplace COMPLEX_COMPLEX + void backward_ip_cc(descriptor_type& desc, sycl::buffer& inout) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + detail::get_descriptors(desc)[1]->compute_backward(inout); + } + } + sycl::event backward_ip_cc(descriptor_type& desc, fwd_type* inout, + const std::vector& dependencies) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw("compute_backward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + return detail::get_descriptors(desc)[1]->compute_backward(inout, dependencies); + } + else { + return {}; + } + } + + // backward inplace REAL_REAL + void backward_ip_rr(descriptor_type& desc, sycl::buffer&, + sycl::buffer&) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + throw oneapi::mkl::unimplemented("DFT", "compute_backward(desc, inout_re, inout_im)", + "portFFT does not support real-real complex storage."); + } + sycl::event backward_ip_rr(descriptor_type& desc, scalar_type*, scalar_type*, + const std::vector&) override { + dft::detail::get_commit(desc)->template compute_call_throw( + "compute_backward"); + throw oneapi::mkl::unimplemented("DFT", + "compute_backward(desc, inout_re, inout_im, dependencies)", + "portFFT does not support real-real complex storage."); + } + + // backward out-of-place COMPLEX_COMPLEX + void backward_op_cc(descriptor_type& desc, sycl::buffer& in, + sycl::buffer& out) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + detail::get_descriptors(desc)[1]->compute_backward(in, out); + } + } + sycl::event backward_op_cc(descriptor_type& desc, bwd_type* in, fwd_type* out, + const std::vector& dependencies) override { + constexpr auto pfft_domain = detail::to_pfft_domain::type::value; + dft::detail::get_commit(desc)->template compute_call_throw("compute_backward"); + + if constexpr (pfft_domain == pfft::domain::COMPLEX) { + return detail::get_descriptors(desc)[1]->compute_backward(in, out, dependencies); + } + else { + return {}; + } + } + + // backward out-of-place REAL_REAL + void backward_op_rr(descriptor_type& desc, sycl::buffer&, + sycl::buffer&, sycl::buffer&, + sycl::buffer&) override { + dft::detail::get_commit(desc)->template compute_call_throw>( + "compute_backward"); + throw oneapi::mkl::unimplemented("DFT", + "compute_backward(desc, in_re, in_im, out_re, out_im)", + "portFFT does not support real-real complex storage."); + } + sycl::event backward_op_rr(descriptor_type& desc, scalar_type*, scalar_type*, scalar_type*, + scalar_type*, const std::vector&) override { + dft::detail::get_commit(desc)->template compute_call_throw( + "compute_backward"); + throw oneapi::mkl::unimplemented( + "DFT", "compute_backward(desc, in_re, in_im, out_re, out_im, deps)", + "portFFT does not support real-real complex storage."); + } +}; +} // namespace detail + +template +dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { + return new detail::portfft_commit(sycl_queue, desc.get_values()); +} + +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); + +} // namespace oneapi::mkl::dft::portfft diff --git a/src/dft/backends/portfft/descriptor.cpp b/src/dft/backends/portfft/descriptor.cpp new file mode 100644 index 000000000..d72d23bb5 --- /dev/null +++ b/src/dft/backends/portfft/descriptor.cpp @@ -0,0 +1,47 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "../../descriptor.cxx" + +#include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp" + +namespace oneapi::mkl::dft { + +template +void descriptor::commit(backend_selector selector) { + if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) { + if (pimpl_) { + pimpl_->get_queue().wait(); + } + pimpl_.reset(portfft::create_commit(*this, selector.get_queue())); + } + pimpl_->commit(values_); +} + +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); + +} // namespace oneapi::mkl::dft diff --git a/src/dft/backends/portfft/mkl_dft_portfft_wrappers.cpp b/src/dft/backends/portfft/mkl_dft_portfft_wrappers.cpp new file mode 100644 index 000000000..28996b0a1 --- /dev/null +++ b/src/dft/backends/portfft/mkl_dft_portfft_wrappers.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/detail/portfft/onemkl_dft_portfft.hpp" +#include "dft/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND portfft + +extern "C" dft_function_table_t mkl_dft_table = { + WRAPPER_VERSION, +#include "dft/backends/backend_wrappers.cxx" +}; + +#undef WRAPPER_VERSION +#undef BACKEND diff --git a/src/dft/backends/portfft/portfft_helper.hpp b/src/dft/backends/portfft/portfft_helper.hpp new file mode 100644 index 000000000..373865f49 --- /dev/null +++ b/src/dft/backends/portfft/portfft_helper.hpp @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_PORTFFT_HELPERS_HPP_ +#define _ONEMKL_DFT_SRC_PORTFFT_HELPERS_HPP_ + +#include + +#include + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" + +namespace pfft = portfft; + +namespace oneapi::mkl::dft::portfft::detail { +template +inline dft::detail::commit_impl *checked_get_commit( + dft::detail::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::portfft) { + throw mkl::invalid_argument("dft/backends/portfft", "get_commit", + "DFT descriptor has not been commited for portFFT"); + } + return commit_handle; +} + +template +using to_pfft_domain = + std::conditional>, + std::integral_constant, + std::integral_constant>; + +template +using storage_type = + std::optional, + detail::to_pfft_domain::type::value>>; + +template +auto get_descriptors(descriptor_type &desc) { + auto commit = detail::checked_get_commit(desc); + return reinterpret_cast *>(commit->get_handle()); +} +} // namespace oneapi::mkl::dft::portfft::detail + +#endif diff --git a/src/dft/backends/rocfft/CMakeLists.txt b/src/dft/backends/rocfft/CMakeLists.txt new file mode 100644 index 000000000..1380c8f0a --- /dev/null +++ b/src/dft/backends/rocfft/CMakeLists.txt @@ -0,0 +1,95 @@ +#=============================================================================== +# Copyright Codeplay Software Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_dft_rocfft) +set(LIB_OBJ ${LIB_NAME}_obj) + + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + descriptor.cpp + commit.cpp + forward.cpp + backward.cpp + $<$: mkl_dft_rocfft_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_dft ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_NAME} + PUBLIC ${ONEMKL_INTERFACE_INCLUDE_DIRS} +) +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +find_package(HIP REQUIRED) +# Require the minimum rocFFT version matching with ROCm 5.4.3. +find_package(rocfft 1.0.21 REQUIRED) + +target_link_libraries(${LIB_OBJ} PRIVATE hip::host roc::rocfft) + +# Allow to compile for different ROCm versions. See the README for the supported +# ROCm versions. +# Starting ROCm >=6.0 the include files are one directory level deeper. +find_path( + rocfft_EXTRA_INCLUDE_DIR + rocfft.h + PATHS ${rocfft_INCLUDE_DIR} + PATH_SUFFIXES rocfft + NO_DEFAULT_PATH + REQUIRED +) +target_include_directories(${LIB_OBJ} PRIVATE ${rocfft_EXTRA_INCLUDE_DIR}) + +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL) + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +# Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/rocfft/backward.cpp b/src/dft/backends/rocfft/backward.cpp new file mode 100644 index 000000000..5ff0e2a1f --- /dev/null +++ b/src/dft/backends/rocfft/backward.cpp @@ -0,0 +1,357 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" + +#include "execute_helper.hpp" +#include "rocfft_handle.hpp" + +#include +#include + +namespace oneapi::mkl::dft::rocfft { +namespace detail { +//forward declaration +template +std::array get_offsets_bwd(dft::detail::commit_impl *commit); + +template +rocfft_plan get_bwd_plan(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[1].plan.value(); +} + +template +rocfft_execution_info get_bwd_info(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[1].info.value(); +} +} // namespace detail +// BUFFER version + +//In-place transform +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout) { + const std::string func_name = "compute_backward(desc, inout)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if constexpr (std::is_floating_point_v>) { + offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer + } + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + + queue.submit([&](sycl::handler &cgh) { + auto inout_acc = inout.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + auto inout_native = reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, inout_acc)) + + offsets[0]); + detail::execute_checked(func_name, plan, &inout_native, nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &inout_re, + sycl::buffer, 1> &inout_im) { + const std::string func_name = "compute_backward(desc, inout_re, inout_im)"; + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + + queue.submit([&](sycl::handler &cgh) { + auto inout_re_acc = inout_re.template get_access(cgh); + auto inout_im_acc = inout_im.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native{ + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, inout_re_acc)) + + offsets[0]), + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, inout_im_acc)) + + offsets[0]) + }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, in, out)"; + auto stream = detail::setup_stream(func_name, ih, info); + + auto in_native = reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, in_acc)) + + offsets[0]); + auto out_native = reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, out_acc)) + + offsets[1]); + detail::execute_checked(func_name, plan, &in_native, &out_native, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_backward(descriptor_type &desc, + sycl::buffer, 1> &in_re, + sycl::buffer, 1> &in_im, + sycl::buffer, 1> &out_re, + sycl::buffer, 1> &out_im) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_re_acc = in_re.template get_access(cgh); + auto in_im_acc = in_im.template get_access(cgh); + auto out_re_acc = out_re.template get_access(cgh); + auto out_im_acc = out_im.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_backward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, in_re, in_im, out_re, out_im)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native{ + reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, in_re_acc)) + + offsets[0]), + reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, in_im_acc)) + + offsets[0]) + }; + std::array out_native{ + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, out_re_acc)) + + offsets[1]), + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, out_im_acc)) + + offsets[1]) + }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//USM version + +//In-place transform +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, fwd *inout, + const std::vector &deps) { + const std::string func_name = "compute_backward(desc, inout, deps)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if constexpr (std::is_floating_point_v>) { + offsets[0] *= 2; // offset is supplied in complex but we offset scalar pointer + } + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + inout += offsets[0]; + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + void *inout_ptr = inout; + detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar *inout_re, + scalar *inout_im, + const std::vector &deps) { + const std::string func_name = "compute_backward(desc, inout_re, inout_im, deps)"; + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native{ inout_re + offsets[0], inout_im + offsets[0] }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//Out-of-place transform +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, bwd *in, + fwd *out, + const std::vector &deps) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + in += offsets[0]; + out += offsets[1]; + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_backward(desc, in, out, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + void *in_ptr = in; + void *out_ptr = out; + detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_backward(descriptor_type &desc, scalar *in_re, + scalar *in_im, + scalar *out_re, + scalar *out_im, + const std::vector &deps) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_bwd_plan(commit); + auto info = detail::get_bwd_info(commit); + auto offsets = detail::get_offsets_bwd(commit); + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = + "compute_backward(desc, in_re, in_im, out_re, out_im, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native{ in_re + offsets[0], in_im + offsets[0] }; + std::array out_native{ out_re + offsets[1], out_im + offsets[1] }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +// Template function instantiations +#include "dft/backends/backend_backward_instantiations.cxx" + +} // namespace oneapi::mkl::dft::rocfft diff --git a/src/dft/backends/rocfft/commit.cpp b/src/dft/backends/rocfft/commit.cpp new file mode 100644 index 000000000..db5a7f965 --- /dev/null +++ b/src/dft/backends/rocfft/commit.cpp @@ -0,0 +1,640 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#if __has_include() +#include +#else +#include +#endif + +#include +#include +#include + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" +#include "oneapi/mkl/dft/types.hpp" + +#include "../stride_helper.hpp" + +#include "rocfft_handle.hpp" + +#include +#include + +namespace oneapi::mkl::dft::rocfft { +namespace detail { + +// rocfft has global setup and cleanup functions which use some global state internally. +// Each can be called multiple times in an application, but due to the global nature, they always need to alternate. +// I don't believe its possible to avoid the user calling rocfft_cleanup in their own code, +// breaking our code, but we can try avoid it for them. +// rocfft_cleanup internally uses some singletons, so it is very difficult to decide if this is safe due to +// the static initialisation order fiasco. +class rocfft_singleton { + rocfft_singleton() { + const auto result = rocfft_setup(); + if (result != rocfft_status_success) { + throw mkl::exception( + "DFT", "rocfft", + "Failed to setup rocfft. returned status " + std::to_string(result)); + } + } + + ~rocfft_singleton() { + (void)rocfft_cleanup(); + } + + // no copies or moves allowed + rocfft_singleton(const rocfft_singleton& other) = delete; + rocfft_singleton(rocfft_singleton&& other) noexcept = delete; + rocfft_singleton& operator=(const rocfft_singleton& other) = delete; + rocfft_singleton& operator=(rocfft_singleton&& other) noexcept = delete; + +public: + static void init() { + static rocfft_singleton instance; + (void)instance; + } +}; + +/// Commit impl class specialization for rocFFT. +template +class rocfft_commit final : public dft::detail::commit_impl { +private: + using scalar_type = typename dft::detail::commit_impl::scalar_type; + // For real to complex transforms, the "transform_type" arg also encodes the direction (e.g. rocfft_transform_type_*_forward vs rocfft_transform_type_*_backward) + // in the plan so we must have one for each direction. + // We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while rocFFT uses a directional "in_distance" and "out_distance". + // The same is also true for "FORWARD_SCALE" and "BACKWARD_SCALE". + // handles[0] is forward, handles[1] is backward + std::array handles{}; + std::int64_t offset_fwd_in, offset_fwd_out, offset_bwd_in, offset_bwd_out; + +public: + rocfft_commit(sycl::queue& queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::rocfft, + config_values) { + if constexpr (prec == dft::detail::precision::DOUBLE) { + if (!queue.get_device().has(sycl::aspect::fp64)) { + throw mkl::exception("DFT", "commit", "Device does not support double precision."); + } + } + // initialise the rocFFT global state + rocfft_singleton::init(); + } + + void clean_plans() { + if (handles[0].plan) { + if (rocfft_plan_destroy(handles[0].plan.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy forward plan."); + } + handles[0].plan = std::nullopt; + } + if (handles[1].plan) { + if (rocfft_plan_destroy(handles[1].plan.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy backward plan."); + } + handles[1].plan = std::nullopt; + } + + if (handles[0].info) { + if (rocfft_execution_info_destroy(handles[0].info.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy forward execution info ."); + } + handles[0].info = std::nullopt; + } + if (handles[1].info) { + if (rocfft_execution_info_destroy(handles[1].info.value()) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy backward execution info ."); + } + handles[1].info = std::nullopt; + } + free_internal_workspace_if_rqd(handles[0], "clear_plans"); + free_internal_workspace_if_rqd(handles[1], "clear_plans"); + } + + void commit(const dft::detail::dft_values& config_values) override { + // this could be a recommit + this->external_workspace_helper_ = + oneapi::mkl::dft::detail::external_workspace_helper( + config_values.workspace_placement == + oneapi::mkl::dft::detail::config_value::WORKSPACE_EXTERNAL); + clean_plans(); + + const rocfft_result_placement placement = + (config_values.placement == dft::config_value::INPLACE) ? rocfft_placement_inplace + : rocfft_placement_notinplace; + + constexpr rocfft_transform_type fwd_type = [] { + if constexpr (dom == dft::domain::COMPLEX) { + return rocfft_transform_type_complex_forward; + } + else { + return rocfft_transform_type_real_forward; + } + }(); + constexpr rocfft_transform_type bwd_type = [] { + if constexpr (dom == dft::domain::COMPLEX) { + return rocfft_transform_type_complex_inverse; + } + else { + return rocfft_transform_type_real_inverse; + } + }(); + + constexpr rocfft_precision precision = [] { + if constexpr (prec == dft::precision::SINGLE) { + return rocfft_precision_single; + } + else { + return rocfft_precision_double; + } + }(); + + const std::size_t dimensions = config_values.dimensions.size(); + + constexpr std::size_t max_supported_dims = 3; + std::array lengths; + // rocfft does dimensions in the reverse order to oneMKL + std::copy(config_values.dimensions.crbegin(), config_values.dimensions.crend(), + lengths.data()); + + const std::size_t number_of_transforms = + static_cast(config_values.number_of_transforms); + + const std::size_t fwd_distance = static_cast(config_values.fwd_dist); + const std::size_t bwd_distance = static_cast(config_values.bwd_dist); + + const rocfft_array_type fwd_array_ty = [&config_values]() { + if constexpr (dom == dft::domain::COMPLEX) { + if (config_values.complex_storage == dft::config_value::COMPLEX_COMPLEX) { + return rocfft_array_type_complex_interleaved; + } + else { + return rocfft_array_type_complex_planar; + } + } + else { + return rocfft_array_type_real; + } + }(); + const rocfft_array_type bwd_array_ty = [&config_values]() { + if constexpr (dom == dft::domain::COMPLEX) { + if (config_values.complex_storage == dft::config_value::COMPLEX_COMPLEX) { + return rocfft_array_type_complex_interleaved; + } + else { + return rocfft_array_type_complex_planar; + } + } + else { + if (config_values.conj_even_storage != dft::config_value::COMPLEX_COMPLEX) { + throw mkl::exception( + "dft/backends/rocfft", __FUNCTION__, + "only COMPLEX_COMPLEX conjugate_even_storage is supported"); + } + return rocfft_array_type_hermitian_interleaved; + } + }(); + + auto stride_api_choice = dft::detail::get_stride_api(config_values); + dft::detail::throw_on_invalid_stride_api("ROCFFT commit", stride_api_choice); + dft::detail::stride_vectors stride_vecs(config_values, stride_api_choice); + + // while rocfft interface accepts offsets, it does not actually handle them + offset_fwd_in = stride_vecs.offset_fwd_in; + offset_fwd_out = stride_vecs.offset_fwd_out; + offset_bwd_in = stride_vecs.offset_bwd_in; + offset_bwd_out = stride_vecs.offset_bwd_out; + + auto func = __FUNCTION__; + auto check_strides = [&](const auto& strides) { + for (int i = 1; i <= dimensions; i++) { + for (int j = 1; j <= dimensions; j++) { + std::int64_t cplx_dim = config_values.dimensions[j - 1]; + std::int64_t real_dim = (dom == dft::domain::REAL && j == dimensions) + ? (cplx_dim / 2 + 1) + : cplx_dim; + if (strides[i] > strides[j] && strides[i] % cplx_dim != 0 && + strides[i] % real_dim != 0) { + // rocfft does not throw, it just produces wrong results + throw oneapi::mkl::unimplemented( + "DFT", func, + "rocfft requires a stride to be divisible by all dimensions associated with smaller strides!"); + } + } + } + }; + // bwd_in/out alias fwd_in/out, so no need to check everything. + check_strides(stride_vecs.vec_a); + check_strides(stride_vecs.vec_b); + + // Reformat slides to conform to rocFFT API. + std::reverse(stride_vecs.vec_a.begin(), stride_vecs.vec_a.end()); + stride_vecs.vec_a.pop_back(); // Offset is not included. + std::reverse(stride_vecs.vec_b.begin(), stride_vecs.vec_b.end()); + stride_vecs.vec_b.pop_back(); // Offset is not included. + + rocfft_plan_description plan_desc; + if (rocfft_plan_description_create(&plan_desc) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create plan description."); + } + + // plan_description can be destroyed afted plan_create + auto description_destroy = [](rocfft_plan_description p) { + if (rocfft_plan_description_destroy(p) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to destroy plan description."); + } + }; + std::unique_ptr + description_destroyer(plan_desc, description_destroy); + + std::array stride_a_indices{ 0, 1, 2 }; + std::sort(&stride_a_indices[0], &stride_a_indices[dimensions], + [&](std::size_t a, std::size_t b) { + return stride_vecs.vec_a[a] < stride_vecs.vec_a[b]; + }); + std::array stride_b_indices{ 0, 1, 2 }; + std::sort(&stride_b_indices[0], &stride_b_indices[dimensions], + [&](std::size_t a, std::size_t b) { + return stride_vecs.vec_b[a] < stride_vecs.vec_b[b]; + }); + std::array lengths_cplx = lengths; + if (dom == dft::domain::REAL) { + lengths_cplx[0] = lengths_cplx[0] / 2 + 1; + } + // When creating real-complex descriptions, the strides will always be wrong for one of the directions. + // This is because the least significant dimension is symmetric. + // If the strides are invalid (too small to fit) then just don't bother creating the plan. + const bool vec_a_valid_as_reals = + dimensions == 1 || + (lengths_cplx[stride_a_indices[0]] <= stride_vecs.vec_a[stride_a_indices[1]] && + (dimensions == 2 || + lengths_cplx[stride_a_indices[0]] * lengths_cplx[stride_a_indices[1]] <= + stride_vecs.vec_a[stride_a_indices[2]])); + const bool vec_b_valid_as_reals = + dimensions == 1 || + (lengths_cplx[stride_b_indices[0]] <= stride_vecs.vec_b[stride_b_indices[1]] && + (dimensions == 2 || + lengths_cplx[stride_b_indices[0]] * lengths_cplx[stride_b_indices[1]] <= + stride_vecs.vec_b[stride_b_indices[2]])); + // Test if the stride vector being used as the fwd domain for each direction has valid strides for that use. + bool valid_forward = + stride_vecs.fwd_in == stride_vecs.vec_a && vec_a_valid_as_reals || vec_b_valid_as_reals; + bool valid_backward = stride_vecs.bwd_out == stride_vecs.vec_a && vec_a_valid_as_reals || + vec_b_valid_as_reals; + + if (!valid_forward && !valid_backward) { + throw mkl::exception("dft/backends/cufft", __FUNCTION__, "Invalid strides."); + } + + if (valid_forward) { + auto res = + rocfft_plan_description_set_data_layout(plan_desc, fwd_array_ty, bwd_array_ty, + nullptr, // in offsets + nullptr, // out offsets + dimensions, + stride_vecs.fwd_in.data(), //in strides + fwd_distance, // in distance + dimensions, + stride_vecs.fwd_out.data(), // out strides + bwd_distance // out distance + ); + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set forward data layout."); + } + + if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.fwd_scale) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set forward scale factor."); + } + + rocfft_plan fwd_plan; + res = rocfft_plan_create(&fwd_plan, placement, fwd_type, precision, dimensions, + lengths.data(), number_of_transforms, plan_desc); + + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create forward plan."); + } + + handles[0].plan = fwd_plan; + + rocfft_execution_info fwd_info; + if (rocfft_execution_info_create(&fwd_info) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create forward execution info."); + } + handles[0].info = fwd_info; + + if (config_values.workspace_placement == config_value::WORKSPACE_AUTOMATIC) { + std::int64_t work_buf_size = get_rocfft_workspace_bytes(handles[0], "commit"); + if (work_buf_size != 0) { + void* work_buf; + if (hipMalloc(&work_buf, work_buf_size) != hipSuccess) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get allocate forward work buffer."); + } + set_workspace_impl(handles[0], reinterpret_cast(work_buf), + work_buf_size, "commit"); + handles[0].buffer = work_buf; + } + } + } + + if (valid_backward) { + auto res = + rocfft_plan_description_set_data_layout(plan_desc, bwd_array_ty, fwd_array_ty, + nullptr, // in offsets + nullptr, // out offsets + dimensions, + stride_vecs.bwd_in.data(), //in strides + bwd_distance, // in distance + dimensions, + stride_vecs.bwd_out.data(), // out strides + fwd_distance // out distance + ); + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set backward data layout."); + } + + if (rocfft_plan_description_set_scale_factor(plan_desc, config_values.bwd_scale) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to set backward scale factor."); + } + + rocfft_plan bwd_plan; + res = rocfft_plan_create(&bwd_plan, placement, bwd_type, precision, dimensions, + lengths.data(), number_of_transforms, plan_desc); + if (res != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create backward rocFFT plan."); + } + handles[1].plan = bwd_plan; + + rocfft_execution_info bwd_info; + if (rocfft_execution_info_create(&bwd_info) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to create backward execution info."); + } + handles[1].info = bwd_info; + + if (config_values.workspace_placement == config_value::WORKSPACE_AUTOMATIC) { + std::int64_t work_buf_size = get_rocfft_workspace_bytes(handles[1], "commit"); + if (work_buf_size != 0) { + void* work_buf; + if (hipMalloc(&work_buf, work_buf_size) != hipSuccess) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get allocate backward work buffer."); + } + set_workspace_impl(handles[1], reinterpret_cast(work_buf), + work_buf_size, "commit"); + handles[1].buffer = work_buf; + } + } + } + } + + ~rocfft_commit() override { + clean_plans(); + } + + // Rule of three. Copying could lead to memory safety issues. + rocfft_commit(const rocfft_commit& other) = delete; + rocfft_commit& operator=(const rocfft_commit& other) = delete; + + void* get_handle() noexcept override { + return handles.data(); + } + + std::array get_offsets_fwd() noexcept { + return { offset_fwd_in, offset_fwd_out }; + } + + std::array get_offsets_bwd() noexcept { + return { offset_bwd_in, offset_bwd_out }; + } + + /** Get the requried worspace size for a rocfft plan. Implementation to be shared by internal and external workspace mechanisms. + + * @param handle rocfft_handle. Expected to have valid rocfft_plan. + * @param function The name of the function to give when generating exceptions + * @return Required space in bytes + **/ + std::int64_t get_rocfft_workspace_bytes(rocfft_handle& handle, const char* function) { + if (!handle.plan) { + throw mkl::exception("dft/backends/rocfft", function, "Missing internal rocfft plan"); + } + std::size_t size = 0; + if (rocfft_plan_get_work_buffer_size(*handle.plan, &size) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", function, + "Failed to get rocfft work buffer size."); + } + return static_cast(size); + } + + /** Set the rocFFT workspace. Implementation to be shared by internal workspace allocation and external workspace + * mechanisms. Does not set handle.buffer. + * + * @param handle rocfft_handle. Expected to have valid rocfft_plan and rocfft_info, but no associated buffer. + * @param workspace Pointer to allocation to use as workspace + * @param workspace_bytes The size (in bytes) of the given workspace + * @param function The name of the function to give when generating exceptions + **/ + void set_workspace_impl(const rocfft_handle& handle, scalar_type* workspace, + std::int64_t workspace_bytes, const char* function) { + if (!handle.info) { + throw mkl::exception( + "dft/backends/rocfft", function, + "Could not set rocFFT workspace - handle has no associated rocfft_info."); + } + if (handle.buffer) { + throw mkl::exception( + "dft/backends/rocfft", function, + "Could not set rocFFT workspace - an internal buffer is already set."); + } + if (workspace_bytes && workspace == nullptr) { + throw mkl::exception("dft/backends/rocfft", function, "Trying to nullptr workspace."); + } + auto info = *handle.info; + if (workspace_bytes && + rocfft_execution_info_set_work_buffer(info, static_cast(workspace), + static_cast(workspace_bytes)) != + rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", function, "Failed to set work buffer."); + } + } + + void free_internal_workspace_if_rqd(rocfft_handle& handle, const char* function) { + if (handle.buffer) { + if (hipFree(*handle.buffer) != hipSuccess) { + throw mkl::exception("dft/backends/rocfft", function, + "Failed to free internal buffer."); + } + handle.buffer = std::nullopt; + } + } + + virtual void set_workspace(scalar_type* usm_workspace) override { + std::int64_t total_workspace_bytes{ this->get_workspace_external_bytes() }; + this->external_workspace_helper_.set_workspace_throw(*this, usm_workspace); + if (handles[0].plan) { + free_internal_workspace_if_rqd(handles[0], "set_workspace"); + set_workspace_impl(handles[0], usm_workspace, total_workspace_bytes, "set_workspace"); + } + if (handles[1].plan) { + free_internal_workspace_if_rqd(handles[1], "set_workspace"); + set_workspace_impl(handles[1], usm_workspace, total_workspace_bytes, "set_workspace"); + } + } + + void set_buffer_workspace(rocfft_handle& handle, sycl::buffer& buffer_workspace) { + auto workspace_bytes = buffer_workspace.size() * sizeof(scalar_type); + if (buffer_workspace.size() == 0) { + return; // Nothing to do. + } + this->get_queue().submit([&](sycl::handler& cgh) { + auto workspace_acc = + buffer_workspace.template get_access(cgh); + cgh.host_task([=](sycl::interop_handle ih) { + auto workspace_native = reinterpret_cast( + ih.get_native_mem(workspace_acc)); + set_workspace_impl(handle, workspace_native, workspace_bytes, "set_workspace"); + }); + }); + this->get_queue().wait_and_throw(); + } + + virtual void set_workspace(sycl::buffer& buffer_workspace) override { + this->external_workspace_helper_.set_workspace_throw(*this, buffer_workspace); + std::size_t total_workspace_count = + static_cast(this->get_workspace_external_bytes()) / sizeof(scalar_type); + if (handles[0].plan) { + free_internal_workspace_if_rqd(handles[0], "set_workspace"); + set_buffer_workspace(handles[0], buffer_workspace); + } + if (handles[1].plan) { + free_internal_workspace_if_rqd(handles[1], "set_workspace"); + set_buffer_workspace(handles[1], buffer_workspace); + } + } + + std::int64_t get_plan_workspace_size_bytes(rocfft_plan_t* plan) { + // plan work buffer + if (plan == nullptr) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Missing internal rocFFT plan."); + } + std::size_t work_buf_size; + if (rocfft_plan_get_work_buffer_size(plan, &work_buf_size) != rocfft_status_success) { + throw mkl::exception("dft/backends/rocfft", __FUNCTION__, + "Failed to get work buffer size."); + } + return static_cast(work_buf_size); + } + + virtual std::int64_t get_workspace_external_bytes_impl() override { + std::int64_t size0 = handles[0].plan ? get_plan_workspace_size_bytes(*handles[0].plan) : 0; + std::int64_t size1 = handles[1].plan ? get_plan_workspace_size_bytes(*handles[1].plan) : 0; + return std::max(size0, size1); + }; + +#define BACKEND rocfft +#include "../backend_compute_signature.cxx" +#undef BACKEND +}; +} // namespace detail + +template +dft::detail::commit_impl* create_commit(const dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { + return new detail::rocfft_commit(sycl_queue, desc.get_values()); +} + +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( + const dft::detail::descriptor&, + sycl::queue&); + +namespace detail { +template +std::array get_offsets_fwd(dft::detail::commit_impl* commit) { + return static_cast*>(commit)->get_offsets_fwd(); +} + +template +std::array get_offsets_bwd(dft::detail::commit_impl* commit) { + return static_cast*>(commit)->get_offsets_bwd(); +} + +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); +template std::array +get_offsets_fwd( + dft::detail::commit_impl*); + +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); +template std::array +get_offsets_bwd( + dft::detail::commit_impl*); + +} //namespace detail + +} // namespace oneapi::mkl::dft::rocfft diff --git a/src/dft/backends/rocfft/descriptor.cpp b/src/dft/backends/rocfft/descriptor.cpp new file mode 100644 index 000000000..83fdbe1dc --- /dev/null +++ b/src/dft/backends/rocfft/descriptor.cpp @@ -0,0 +1,51 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/descriptor.hpp" +#include "../../descriptor.cxx" + +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" + +namespace oneapi { +namespace mkl { +namespace dft { + +template +void descriptor::commit(backend_selector selector) { + if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) { + if (pimpl_) { + pimpl_->get_queue().wait(); + } + pimpl_.reset(rocfft::create_commit(*this, selector.get_queue())); + } + pimpl_->commit(values_); +} + +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); +template void descriptor::commit( + backend_selector); + +} //namespace dft +} //namespace mkl +} //namespace oneapi diff --git a/src/dft/backends/rocfft/execute_helper.hpp b/src/dft/backends/rocfft/execute_helper.hpp new file mode 100644 index 000000000..4dff6831d --- /dev/null +++ b/src/dft/backends/rocfft/execute_helper.hpp @@ -0,0 +1,97 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_ +#define _ONEMKL_DFT_SRC_ROCFFT_EXECUTE_HELPER_HPP_ + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/dft/detail/commit_impl.hpp" +#include "oneapi/mkl/dft/detail/descriptor_impl.hpp" +#include "oneapi/mkl/dft/types.hpp" +#include "oneapi/mkl/exceptions.hpp" + +#include +#include + +namespace oneapi::mkl::dft::rocfft::detail { + +template +inline dft::detail::commit_impl *checked_get_commit( + dft::detail::descriptor &desc) { + auto commit_handle = dft::detail::get_commit(desc); + if (commit_handle == nullptr || commit_handle->get_backend() != backend::rocfft) { + throw mkl::invalid_argument("dft/backends/rocfft", "get_commit", + "DFT descriptor has not been commited for rocFFT"); + } + return commit_handle; +} + +/// Throw an mkl::invalid_argument if the runtime param in the descriptor does not match +/// the expected value. +template +inline auto expect_config(DescT &desc, const char *message) { + dft::config_value actual{ 0 }; + desc.get_value(Param, &actual); + if (actual != Expected) { + throw mkl::invalid_argument("dft/backends/rocfft", "expect_config", message); + } +} + +template +inline void *native_mem(sycl::interop_handle &ih, Acc &buf) { + return ih.get_native_mem(buf); +} + +inline hipStream_t setup_stream(const std::string &func, sycl::interop_handle &ih, + rocfft_execution_info info) { + auto stream = ih.get_native_queue(); + auto result = rocfft_execution_info_set_stream(info, stream); + if (result != rocfft_status_success) { + throw oneapi::mkl::exception( + "dft/backends/rocfft", func, + "rocfft_execution_info_set_stream returned " + std::to_string(result)); + } + return stream; +} + +inline void sync_checked(const std::string &func, hipStream_t stream) { + auto result = hipStreamSynchronize(stream); + if (result != hipSuccess) { + throw oneapi::mkl::exception("dft/backends/rocfft", func, + "hipStreamSynchronize returned " + std::to_string(result)); + } +} + +inline void execute_checked(const std::string &func, const rocfft_plan plan, void *in_buffer[], + void *out_buffer[], rocfft_execution_info info) { + auto result = rocfft_execute(plan, in_buffer, out_buffer, info); + if (result != rocfft_status_success) { + throw oneapi::mkl::exception("dft/backends/rocfft", func, + "rocfft_execute returned " + std::to_string(result)); + } +} + +} // namespace oneapi::mkl::dft::rocfft::detail + +#endif diff --git a/src/dft/backends/rocfft/forward.cpp b/src/dft/backends/rocfft/forward.cpp new file mode 100644 index 000000000..70d3d0f97 --- /dev/null +++ b/src/dft/backends/rocfft/forward.cpp @@ -0,0 +1,358 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/exceptions.hpp" + +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" +#include "oneapi/mkl/dft/descriptor.hpp" + +#include "execute_helper.hpp" +#include "rocfft_handle.hpp" + +#include +#include + +namespace oneapi::mkl::dft::rocfft { + +namespace detail { +//forward declaration +template +std::array get_offsets_fwd(dft::detail::commit_impl *commit); + +template +rocfft_plan get_fwd_plan(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[0].plan.value(); +} + +template +rocfft_execution_info get_fwd_info(dft::detail::commit_impl *commit) { + return static_cast(commit->get_handle())[0].info.value(); +} +} // namespace detail + +// BUFFER version + +//In-place transform +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout) { + const std::string func_name = "compute_forward(desc, inout)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if constexpr (std::is_floating_point_v>) { + offsets[1] *= 2; // offset is supplied in complex but we offset scalar pointer + } + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + + queue.submit([&](sycl::handler &cgh) { + auto inout_acc = inout.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + auto inout_native = reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, inout_acc)) + + offsets[0]); + detail::execute_checked(func_name, plan, &inout_native, nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &inout_re, + sycl::buffer, 1> &inout_im) { + const std::string func_name = "compute_forward(desc, inout_re, inout_im)"; + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + + queue.submit([&](sycl::handler &cgh) { + auto inout_re_acc = inout_re.template get_access(cgh); + auto inout_im_acc = inout_im.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native{ + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, inout_re_acc)) + + offsets[0]), + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, inout_im_acc)) + + offsets[0]) + }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, sycl::buffer, 1> &in, + sycl::buffer, 1> &out) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_acc = in.template get_access(cgh); + auto out_acc = out.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, in, out)"; + auto stream = detail::setup_stream(func_name, ih, info); + + auto in_native = reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, in_acc)) + + offsets[0]); + auto out_native = reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, out_acc)) + + offsets[1]); + detail::execute_checked(func_name, plan, &in_native, &out_native, info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT void compute_forward(descriptor_type &desc, + sycl::buffer, 1> &in_re, + sycl::buffer, 1> &in_im, + sycl::buffer, 1> &out_re, + sycl::buffer, 1> &out_im) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + queue.submit([&](sycl::handler &cgh) { + auto in_re_acc = in_re.template get_access(cgh); + auto in_im_acc = in_im.template get_access(cgh); + auto out_re_acc = out_re.template get_access(cgh); + auto out_im_acc = out_im.template get_access(cgh); + commit->add_buffer_workspace_dependency_if_rqd("compute_forward", cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, in_re, in_im, out_re, out_im)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native{ + reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, in_re_acc)) + + offsets[0]), + reinterpret_cast( + reinterpret_cast *>(detail::native_mem(ih, in_im_acc)) + + offsets[0]) + }; + std::array out_native{ + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, out_re_acc)) + + offsets[1]), + reinterpret_cast(reinterpret_cast *>( + detail::native_mem(ih, out_im_acc)) + + offsets[1]) + }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); +} + +//USM version + +//In-place transform +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *inout, + const std::vector &deps) { + const std::string func_name = "compute_forward(desc, inout, deps)"; + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if constexpr (std::is_floating_point_v>) { + offsets[1] *= 2; // offset is supplied in complex but we offset scalar pointer + } + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + inout += offsets[0]; + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + void *inout_ptr = inout; + detail::execute_checked(func_name, plan, &inout_ptr, nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar *inout_re, + scalar *inout_im, + const std::vector &deps) { + const std::string func_name = "compute_forward(desc, inout_re, inout_im, deps)"; + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + if (offsets[0] != offsets[1]) { + throw oneapi::mkl::unimplemented( + "DFT", func_name, + "rocFFT requires input and output offsets (first value in strides) to be equal for in-place transforms!"); + } + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + cgh.host_task([=](sycl::interop_handle ih) { + auto stream = detail::setup_stream(func_name, ih, info); + + std::array inout_native{ inout_re + offsets[0], inout_im + offsets[0] }; + detail::execute_checked(func_name, plan, inout_native.data(), nullptr, info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//Out-of-place transform +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, fwd *in, + bwd *out, + const std::vector &deps) { + detail::expect_config( + desc, "Unexpected value for placement"); + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + in += offsets[0]; + out += offsets[1]; + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = "compute_forward(desc, in, out, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + void *in_ptr = in; + void *out_ptr = out; + detail::execute_checked(func_name, plan, &in_ptr, &out_ptr, info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +//Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format +template +ONEMKL_EXPORT sycl::event compute_forward(descriptor_type &desc, scalar *in_re, + scalar *in_im, + scalar *out_re, + scalar *out_im, + const std::vector &deps) { + auto commit = detail::checked_get_commit(desc); + auto queue = commit->get_queue(); + auto plan = detail::get_fwd_plan(commit); + auto info = detail::get_fwd_info(commit); + auto offsets = detail::get_offsets_fwd(commit); + + sycl::event sycl_event = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(deps); + commit->depend_on_last_usm_workspace_event_if_rqd(cgh); + + cgh.host_task([=](sycl::interop_handle ih) { + const std::string func_name = + "compute_forward(desc, in_re, in_im, out_re, out_im, deps)"; + auto stream = detail::setup_stream(func_name, ih, info); + + std::array in_native{ in_re + offsets[0], in_im + offsets[0] }; + std::array out_native{ out_re + offsets[1], out_im + offsets[1] }; + detail::execute_checked(func_name, plan, in_native.data(), out_native.data(), info); + detail::sync_checked(func_name, stream); + }); + }); + commit->set_last_usm_workspace_event_if_rqd(sycl_event); + return sycl_event; +} + +// Template function instantiations +#include "dft/backends/backend_forward_instantiations.cxx" + +} // namespace oneapi::mkl::dft::rocfft diff --git a/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp b/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp new file mode 100644 index 000000000..c8f0e35c7 --- /dev/null +++ b/src/dft/backends/rocfft/mkl_dft_rocfft_wrappers.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/dft/detail/rocfft/onemkl_dft_rocfft.hpp" +#include "dft/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND rocfft + +extern "C" dft_function_table_t mkl_dft_table = { + WRAPPER_VERSION, +#include "dft/backends/backend_wrappers.cxx" +}; + +#undef WRAPPER_VERSION +#undef BACKEND diff --git a/src/dft/backends/rocfft/rocfft_handle.hpp b/src/dft/backends/rocfft/rocfft_handle.hpp new file mode 100644 index 000000000..ea4f44d68 --- /dev/null +++ b/src/dft/backends/rocfft/rocfft_handle.hpp @@ -0,0 +1,34 @@ +/******************************************************************************* +* Copyright Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_DFT_SRC_ROCFFT_ROCFFT_HANDLE_HPP_ +#define _ONEMKL_DFT_SRC_ROCFFT_ROCFFT_HANDLE_HPP_ + +#include + +struct rocfft_plan_t; +struct rocfft_execution_info_t; + +struct rocfft_handle { + std::optional plan = std::nullopt; + std::optional info = std::nullopt; + std::optional buffer = std::nullopt; +}; + +#endif diff --git a/src/dft/backends/stride_helper.hpp b/src/dft/backends/stride_helper.hpp new file mode 100644 index 000000000..6c3146c99 --- /dev/null +++ b/src/dft/backends/stride_helper.hpp @@ -0,0 +1,151 @@ +/******************************************************************************* +* Copyright 2024 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _DFT_DETAIL_STRIDE_HELPER_HPP_ +#define _DFT_DETAIL_STRIDE_HELPER_HPP_ + +namespace oneapi::mkl::dft::detail { + +enum class stride_api { + INVALID, // Cannot choose: no valid choice + FB_STRIDES, // Use FWD_STRIDES and BWD_STRIDES + IO_STRIDES // Use INPUT_STRIDES and OUTPUT_STRIDES +}; + +/** Throw invalid_argument for stride_api::INVALID + * @param function Function name to include in exception. + * @param stride_choice The stride_api to check if INVALID. Default is INVALID. + * + * @throws invalid_argument on stride_api::INVALID. + */ +inline void throw_on_invalid_stride_api(const char* function, + stride_api stride_choice = stride_api::INVALID) { + if (stride_choice == stride_api::INVALID) { + throw mkl::invalid_argument( + "DFT", function, + "Invalid INPUT/OUTPUT or FWD/BACKWARD strides. API usage may have been mixed."); + } +} + +// Helper class for mapping input / output strides for backend DFTs to config values. +// Intended to be abused as required for each backend. +template +struct stride_vectors { + using stride_elem_t = StrideElemT; + using stride_vec_t = std::vector; + + // The stride API being used. + const stride_api stride_choice; + + // The storage for strides. vec_a is forward or input. + stride_vec_t vec_a, vec_b; + + // Input and output strides for forward and backward DFTs. + stride_vec_t &fwd_in, &fwd_out, &bwd_in, &bwd_out; + + // Input and output offsets for forward and backward DFTs. + StrideElemT offset_fwd_in, offset_fwd_out, offset_bwd_in, offset_bwd_out; + + /** Initialize the forward / backwards input and output strides for this object. + * @tparam ConfigT The config values type. + * @param config_values The DFT config values. + * @param stride_api The stride API choice. Must not be INVALID. + **/ + template + stride_vectors(const ConfigT& config_values, stride_api stride_choice) + : stride_choice(stride_choice), + fwd_in(vec_a), + fwd_out(vec_b), + bwd_in(stride_choice == stride_api::FB_STRIDES ? vec_b : vec_a), + bwd_out(stride_choice == stride_api::FB_STRIDES ? vec_a : vec_b) { + if (stride_choice == stride_api::INVALID) { + throw mkl::exception("DFT", "detail::stride_vector constructor", + "Internal error: invalid stride API"); + } + auto& v1 = stride_choice == stride_api::FB_STRIDES ? config_values.fwd_strides + : config_values.input_strides; + auto& v2 = stride_choice == stride_api::FB_STRIDES ? config_values.bwd_strides + : config_values.output_strides; + + vec_a.resize(v1.size()); + vec_b.resize(v2.size()); + for (std::size_t i{ 0 }; i < v1.size(); ++i) { // v1.size() == v2.size() + if constexpr (std::is_unsigned_v) { + if (v1[i] < 0 || v2[i] < 0) { + throw mkl::unimplemented("DFT", "commit", + "Backend does not support negative strides."); + } + } + vec_a[i] = static_cast(v1[i]); + vec_b[i] = static_cast(v2[i]); + } + offset_fwd_in = fwd_in[0]; + offset_fwd_out = fwd_out[0]; + offset_bwd_in = bwd_in[0]; + offset_bwd_out = bwd_out[0]; + } +}; + +/** Determines whether INPUT/OUTPUT strides, or FWD/BWD strides API is used. + * @tparam ConfigT The config values type. + * @param config_values The DFT config values. + * @returns Stride choice. INVALID if the choice could not be determined. + * + * @note This does not attempt to determine that the set strides are valid. + */ +template +inline stride_api get_stride_api(const ConfigT& config_values) { + auto n = config_values.dimensions.size(); + // Test if FWD/BWD strides look like they should be used. If yes, use them. + if (config_values.fwd_strides.size() == n + 1 && config_values.bwd_strides.size() == n + 1) { + auto all_zero_fwd = true; + auto all_zero_bwd = true; + // If INPUT or OUTPUT have been set, these will be zeroed. + for (auto v : config_values.fwd_strides) { + all_zero_fwd = v == 0 && all_zero_fwd; + } + for (auto v : config_values.bwd_strides) { + all_zero_bwd = v == 0 && all_zero_bwd; + } + if (!all_zero_fwd && !all_zero_bwd) { // Both must be non-zero. + return stride_api::FB_STRIDES; + } + } + // FWD/BWD invalid. Test INPUT/OUTPUT for validity. + if (config_values.input_strides.size() == n + 1 && + config_values.output_strides.size() == n + 1) { + auto all_zero_in = true; + auto all_zero_out = true; + // If FWD or BWD have been set, these will be zeroed. + for (auto v : config_values.input_strides) { + all_zero_in = v == 0 && all_zero_in; + } + for (auto v : config_values.output_strides) { + all_zero_out = v == 0 && all_zero_out; + } + if (!all_zero_in && !all_zero_out) { // Both must be non-zero. + return stride_api::IO_STRIDES; + } + } + return stride_api::INVALID; +} + +} // namespace oneapi::mkl::dft::detail + +#endif //_DFT_DETAIL_STRIDE_HELPER_HPP_ diff --git a/src/dft/descriptor.cxx b/src/dft/descriptor.cxx index f870ff677..a9acd3b9e 100644 --- a/src/dft/descriptor.cxx +++ b/src/dft/descriptor.cxx @@ -29,17 +29,20 @@ namespace dft { namespace detail { // Compute the default strides. Modifies real_strides and complex_strides arguments. -void compute_default_strides(const std::vector& dimensions, - std::vector& input_strides, - std::vector& output_strides) { - int rank = dimensions.size(); +inline void compute_default_strides(const std::vector& dimensions, + std::vector& fwd_strides, + std::vector& bwd_strides) { + auto rank = dimensions.size(); std::vector strides(rank + 1, 1); - for (int i = rank - 1; i > 0; --i) { + for (auto i = rank - 1; i > 0; --i) { strides[i] = strides[i + 1] * dimensions[i]; } strides[0] = 0; - output_strides = strides; - input_strides = std::move(strides); + // Fwd/Bwd strides and Input/Output strides being the same by default means + // that we don't have to specify if we default to using fwd/bwd strides or + // input/output strides. + bwd_strides = strides; + fwd_strides = std::move(strides); } template @@ -69,16 +72,23 @@ void descriptor::set_value(config_param param, ...) { case config_param::INPUT_STRIDES: detail::set_value(values_, va_arg(vl, std::int64_t*)); break; - case config_param::OUTPUT_STRIDES: { + case config_param::OUTPUT_STRIDES: detail::set_value(values_, va_arg(vl, std::int64_t*)); break; - } + case config_param::FWD_STRIDES: + detail::set_value(values_, va_arg(vl, std::int64_t*)); + break; + case config_param::BWD_STRIDES: + detail::set_value(values_, va_arg(vl, std::int64_t*)); + break; // VA arg promotes float args to double, so the following is always double: case config_param::FORWARD_SCALE: - detail::set_value(values_, va_arg(vl, double)); + detail::set_value(values_, + static_cast(va_arg(vl, double))); break; case config_param::BACKWARD_SCALE: - detail::set_value(values_, va_arg(vl, double)); + detail::set_value( + values_, static_cast(va_arg(vl, double))); break; case config_param::NUMBER_OF_TRANSFORMS: detail::set_value(values_, @@ -112,6 +122,12 @@ void descriptor::set_value(config_param param, ...) { case config_param::WORKSPACE: detail::set_value(values_, va_arg(vl, config_value)); break; + case config_param::WORKSPACE_PLACEMENT: + detail::set_value(values_, va_arg(vl, config_value)); + break; + case config_param::WORKSPACE_EXTERNAL_BYTES: + throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); + break; case config_param::PACKED_FORMAT: detail::set_value(values_, va_arg(vl, config_value)); break; @@ -134,10 +150,12 @@ descriptor::descriptor(std::vector dimensions) { "Invalid dimension value (negative or 0)."); } } + compute_default_strides(dimensions, values_.fwd_strides, values_.bwd_strides); // Assume forward transform. - compute_default_strides(dimensions, values_.input_strides, values_.output_strides); - values_.bwd_scale = 1.0; - values_.fwd_scale = 1.0; + values_.input_strides = values_.fwd_strides; + values_.output_strides = values_.bwd_strides; + values_.bwd_scale = real_t(1.0); + values_.fwd_scale = real_t(1.0); values_.number_of_transforms = 1; values_.fwd_dist = 1; values_.bwd_dist = 1; @@ -146,6 +164,7 @@ descriptor::descriptor(std::vector dimensions) { values_.real_storage = config_value::REAL_REAL; values_.conj_even_storage = config_value::COMPLEX_COMPLEX; values_.workspace = config_value::ALLOW; + values_.workspace_placement = config_value::WORKSPACE_AUTOMATIC; values_.ordering = config_value::ORDERED; values_.transpose = false; values_.packed_format = config_value::CCE_FORMAT; @@ -157,12 +176,16 @@ descriptor::descriptor(std::int64_t length) : descriptor(std::vector{ length }) {} template -descriptor::~descriptor() {} +descriptor::descriptor(descriptor&& other) = default; + +template +descriptor& descriptor::operator=(descriptor&&) = default; + +template +descriptor::~descriptor() = default; template void descriptor::get_value(config_param param, ...) const { - int err = 0; - using real_t = std::conditional_t; va_list vl; va_start(vl, param); if (va_arg(vl, void*) == nullptr) { @@ -172,7 +195,9 @@ void descriptor::get_value(config_param param, ...) const { va_start(vl, param); switch (param) { case config_param::FORWARD_DOMAIN: *va_arg(vl, dft::domain*) = dom; break; - case config_param::DIMENSION: *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size()); break; + case config_param::DIMENSION: + *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size()); + break; case config_param::LENGTHS: std::copy(values_.dimensions.begin(), values_.dimensions.end(), va_arg(vl, std::int64_t*)); @@ -203,9 +228,30 @@ void descriptor::get_value(config_param param, ...) const { std::copy(values_.output_strides.begin(), values_.output_strides.end(), va_arg(vl, std::int64_t*)); break; + case config_param::FWD_STRIDES: + std::copy(values_.fwd_strides.begin(), values_.fwd_strides.end(), + va_arg(vl, std::int64_t*)); + break; + case config_param::BWD_STRIDES: + std::copy(values_.bwd_strides.begin(), values_.bwd_strides.end(), + va_arg(vl, std::int64_t*)); + break; case config_param::FWD_DISTANCE: *va_arg(vl, std::int64_t*) = values_.fwd_dist; break; case config_param::BWD_DISTANCE: *va_arg(vl, std::int64_t*) = values_.bwd_dist; break; case config_param::WORKSPACE: *va_arg(vl, config_value*) = values_.workspace; break; + case config_param::WORKSPACE_PLACEMENT: + *va_arg(vl, config_value*) = values_.workspace_placement; + break; + case config_param::WORKSPACE_EXTERNAL_BYTES: + if (!pimpl_) { + throw mkl::invalid_argument( + "DFT", "get_value", + "Cannot query WORKSPACE_EXTERNAL_BYTES on uncommitted descriptor."); + } + else { + *va_arg(vl, std::int64_t*) = pimpl_->get_workspace_external_bytes(); + } + break; case config_param::ORDERING: *va_arg(vl, config_value*) = values_.ordering; break; case config_param::TRANSPOSE: *va_arg(vl, int*) = values_.transpose; break; case config_param::PACKED_FORMAT: *va_arg(vl, config_value*) = values_.packed_format; break; @@ -218,6 +264,28 @@ void descriptor::get_value(config_param param, ...) const { va_end(vl); } +template +void descriptor::set_workspace(scalar_type* usm_workspace) { + if (pimpl_) { + return pimpl_->set_workspace(usm_workspace); + } + else { + throw mkl::uninitialized("DFT", "set_workspace", + "Can only set workspace on committed descriptor."); + } +} + +template +void descriptor::set_workspace(sycl::buffer& buffer_workspace) { + if (pimpl_) { + return pimpl_->set_workspace(buffer_workspace); + } + else { + throw mkl::uninitialized("DFT", "set_workspace", + "Can only set workspace on committed descriptor."); + } +} + template class descriptor; template class descriptor; template class descriptor; diff --git a/src/dft/descriptor_config_helper.hpp b/src/dft/descriptor_config_helper.hpp index cc059734f..dc8c97ac2 100644 --- a/src/dft/descriptor_config_helper.hpp +++ b/src/dft/descriptor_config_helper.hpp @@ -30,6 +30,19 @@ namespace mkl { namespace dft { namespace detail { +/** Helper: sets both input vectors to zeros. + * Used for enforcing consistency when using FWD/BWD_STRIDES and + * INPUT/OUTPUT_STRIDES. + */ +static void reset_strides_to_zero(std::vector& v1, std::vector& v2) { + for (auto& v : v1) { + v = 0; + } + for (auto& v : v2) { + v = 0; + } +} + /// Helper to get real type from precision. template struct real_helper; @@ -78,10 +91,14 @@ PARAM_TYPE_HELPER(config_param::OUTPUT_STRIDES, std::int64_t*) PARAM_TYPE_HELPER(config_param::FWD_DISTANCE, std::int64_t) PARAM_TYPE_HELPER(config_param::BWD_DISTANCE, std::int64_t) PARAM_TYPE_HELPER(config_param::WORKSPACE, config_value) +PARAM_TYPE_HELPER(config_param::WORKSPACE_PLACEMENT, config_value) +PARAM_TYPE_HELPER(config_param::WORKSPACE_EXTERNAL_BYTES, std::int64_t) PARAM_TYPE_HELPER(config_param::ORDERING, config_value) PARAM_TYPE_HELPER(config_param::TRANSPOSE, bool) PARAM_TYPE_HELPER(config_param::PACKED_FORMAT, config_value) PARAM_TYPE_HELPER(config_param::COMMIT_STATUS, config_value) +PARAM_TYPE_HELPER(config_param::FWD_STRIDES, std::int64_t*) +PARAM_TYPE_HELPER(config_param::BWD_STRIDES, std::int64_t*) #undef PARAM_TYPE_HELPER /** Set a value in dft_values, throwing on invalid args. @@ -98,7 +115,7 @@ void set_value(dft_values& vals, if (set_val == nullptr) { throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); } - for (int i{ 0 }; i < vals.dimensions.size(); ++i) { + for (std::size_t i{ 0 }; i < vals.dimensions.size(); ++i) { if (set_val[i] <= 0) { throw mkl::invalid_argument("DFT", "set_value", "Invalid length value (negative or 0)."); @@ -157,18 +174,23 @@ void set_value(dft_values& vals, "Placement must be inplace or not inplace."); } } +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" else if constexpr (Param == config_param::INPUT_STRIDES) { if (set_val == nullptr) { throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); } + reset_strides_to_zero(vals.fwd_strides, vals.bwd_strides); std::copy(set_val, set_val + vals.dimensions.size() + 1, vals.input_strides.begin()); } else if constexpr (Param == config_param::OUTPUT_STRIDES) { if (set_val == nullptr) { throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); } + reset_strides_to_zero(vals.fwd_strides, vals.bwd_strides); std::copy(set_val, set_val + vals.dimensions.size() + 1, vals.output_strides.begin()); } +#pragma clang diagnostic pop else if constexpr (Param == config_param::FWD_DISTANCE) { vals.fwd_dist = set_val; } @@ -183,6 +205,19 @@ void set_value(dft_values& vals, throw mkl::invalid_argument("DFT", "set_value", "Workspace must be allow or avoid."); } } + else if constexpr (Param == config_param::WORKSPACE_PLACEMENT) { + if (set_val == config_value::WORKSPACE_AUTOMATIC || + set_val == config_value::WORKSPACE_EXTERNAL) { + vals.workspace_placement = set_val; + } + else { + throw mkl::invalid_argument( + "DFT", "set_value", "Workspace must be WORKSPACE_AUTOMATIC or WORKSPACE_EXTERNAL."); + } + } + else if constexpr (Param == config_param::WORKSPACE_EXTERNAL_BYTES) { + throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); + } else if constexpr (Param == config_param::ORDERING) { if (set_val == config_value::ORDERED || set_val == config_value::BACKWARD_SCRAMBLED) { vals.ordering = set_val; @@ -203,6 +238,20 @@ void set_value(dft_values& vals, throw mkl::invalid_argument("DFT", "set_value", "Packed format must be CCE."); } } + else if constexpr (Param == config_param::FWD_STRIDES) { + if (set_val == nullptr) { + throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); + } + reset_strides_to_zero(vals.input_strides, vals.output_strides); + std::copy(set_val, set_val + vals.dimensions.size() + 1, vals.fwd_strides.begin()); + } + else if constexpr (Param == config_param::BWD_STRIDES) { + if (set_val == nullptr) { + throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); + } + reset_strides_to_zero(vals.input_strides, vals.output_strides); + std::copy(set_val, set_val + vals.dimensions.size() + 1, vals.bwd_strides.begin()); + } } } // namespace detail diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index c3016ddc8..b0c421fb0 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -25,10 +25,7 @@ #include "dft/function_table.hpp" #include "oneapi/mkl/detail/get_device_id.hpp" -namespace oneapi { -namespace mkl { -namespace dft { -namespace detail { +namespace oneapi::mkl::dft::detail { static oneapi::mkl::detail::table_initializer function_tables; @@ -72,305 +69,4 @@ inline oneapi::mkl::device get_device(descriptor& desc, const char* f return get_device_id(get_commit(desc)->get_queue()); } -} // namespace detail - -#define ONEAPI_MKL_DFT_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - \ - /*Buffer version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_inplace_real_##EXT(desc, inout); \ - } \ - \ - /*In-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, T_BACKWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_inplace_complex_##EXT(desc, inout); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout_re, \ - sycl::buffer & inout_im) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_inplace_split_##EXT(desc, inout_re, inout_im); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_FORWARD, T_BACKWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_outofplace_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_outofplace_real_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im); \ - } \ - \ - /*USM version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT sycl::event compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_inplace_real_##EXT(desc, inout, dependencies); \ - } \ - \ - /*In-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_BACKWARD>( \ - dft::detail::descriptor & desc, T_BACKWARD * inout, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_inplace_complex_##EXT(desc, inout, dependencies); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event compute_forward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_FORWARD, T_BACKWARD>( \ - dft::detail::descriptor & desc, T_FORWARD * in, T_BACKWARD * out, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_outofplace_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in, T_REAL * out, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_outofplace_real_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ - dependencies); \ - } \ - \ - /*Buffer version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_inplace_real_##EXT(desc, inout); \ - } \ - \ - /*In-place transform - complex */ \ - template <> \ - ONEMKL_EXPORT void compute_backward, T_BACKWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & inout) { \ - detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_inplace_complex_##EXT(desc, inout); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & inout_re, \ - sycl::buffer & inout_im) { \ - detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_inplace_split_##EXT(desc, inout_re, inout_im); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_BACKWARD, T_FORWARD>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_outofplace_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_outofplace_real_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, sycl::buffer & in_re, \ - sycl::buffer & in_im, sycl::buffer & out_re, \ - sycl::buffer & out_im) { \ - detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im); \ - } \ - \ - /*USM version*/ \ - \ - /*In-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_inplace_real_##EXT(desc, inout, dependencies); \ - } \ - \ - /*In-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_BACKWARD>( \ - dft::detail::descriptor & desc, T_BACKWARD * inout, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_inplace_complex_##EXT(desc, inout, dependencies); \ - } \ - \ - /*In-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * inout_re, T_REAL * inout_im, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_inplace_split_##EXT(desc, inout_re, inout_im, dependencies); \ - } \ - \ - /*Out-of-place transform*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_BACKWARD, T_FORWARD>( \ - dft::detail::descriptor & desc, T_BACKWARD * in, T_FORWARD * out, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_outofplace_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform - real*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in, T_REAL * out, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_outofplace_real_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform, using config_param::COMPLEX_STORAGE=config_value::REAL_REAL data format*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_REAL, T_REAL>( \ - dft::detail::descriptor & desc, T_REAL * in_re, T_REAL * in_im, \ - T_REAL * out_re, T_REAL * out_im, const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_outofplace_split_##EXT(desc, in_re, in_im, out_re, out_im, \ - dependencies); \ - } - -// Signatures with forward_t=complex, backwards_t=complex are already instantiated for complex domain -// but not real domain. -#define ONEAPI_MKL_DFT_REAL_ONLY_SIGNATURES(EXT, PRECISION, T_COMPLEX) \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_forward, T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_buffer_outofplace_complex_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_forward, T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, T_COMPLEX * in, T_COMPLEX * out, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_forward")] \ - .compute_forward_usm_outofplace_complex_##EXT(desc, in, out, dependencies); \ - } \ - \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT void \ - compute_backward, T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, sycl::buffer & in, \ - sycl::buffer & out) { \ - detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_buffer_outofplace_complex_##EXT(desc, in, out); \ - } \ - \ - /*Out-of-place transform - complex*/ \ - template <> \ - ONEMKL_EXPORT sycl::event \ - compute_backward, T_COMPLEX, T_COMPLEX>( \ - dft::detail::descriptor & desc, T_COMPLEX * in, T_COMPLEX * out, \ - const std::vector& dependencies) { \ - return detail::function_tables[detail::get_device(desc, "compute_backward")] \ - .compute_backward_usm_outofplace_complex_##EXT(desc, in, out, dependencies); \ - } - -ONEAPI_MKL_DFT_SIGNATURES(f, dft::detail::precision::SINGLE, dft::detail::domain::REAL, float, - float, std::complex) -ONEAPI_MKL_DFT_REAL_ONLY_SIGNATURES(f, dft::detail::precision::SINGLE, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(c, dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX, float, - std::complex, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(d, dft::detail::precision::DOUBLE, dft::detail::domain::REAL, double, - double, std::complex) -ONEAPI_MKL_DFT_REAL_ONLY_SIGNATURES(d, dft::detail::precision::DOUBLE, std::complex) -ONEAPI_MKL_DFT_SIGNATURES(z, dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX, double, - std::complex, std::complex) -#undef ONEAPI_MKL_DFT_SIGNATURES - -} // namespace dft -} // namespace mkl -} // namespace oneapi +} // namespace oneapi::mkl::dft::detail diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index 589647993..9146f239e 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -57,111 +57,6 @@ typedef struct { const oneapi::mkl::dft::descriptor& desc, sycl::queue& sycl_queue); - -#define ONEAPI_MKL_DFT_BACKEND_SIGNATURES(EXT, PRECISION, DOMAIN, T_REAL, T_FORWARD, T_BACKWARD) \ - void (*compute_forward_buffer_inplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & inout); \ - void (*compute_forward_buffer_inplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & inout); \ - void (*compute_forward_buffer_inplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & inout_re, sycl::buffer & inout_im); \ - void (*compute_forward_buffer_outofplace_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out); \ - void (*compute_forward_buffer_outofplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out); \ - void (*compute_forward_buffer_outofplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out); \ - void (*compute_forward_buffer_outofplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in_re, sycl::buffer & in_im, \ - sycl::buffer & out_re, sycl::buffer & out_im); \ - sycl::event (*compute_forward_usm_inplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * inout, \ - const std::vector& dependencies); \ - sycl::event (*compute_forward_usm_inplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_BACKWARD * inout, \ - const std::vector& dependencies); \ - sycl::event (*compute_forward_usm_inplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * inout_re, \ - T_REAL * inout_im, const std::vector& dependencies); \ - sycl::event (*compute_forward_usm_outofplace_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_FORWARD * in, \ - T_BACKWARD * out, const std::vector& dependencies); \ - sycl::event (*compute_forward_usm_outofplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * in, T_REAL * out, \ - const std::vector& dependencies); \ - sycl::event (*compute_forward_usm_outofplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_BACKWARD * in, \ - T_BACKWARD * out, const std::vector& dependencies); \ - sycl::event (*compute_forward_usm_outofplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * in_re, \ - T_REAL * in_im, T_REAL * out_re, T_REAL * out_im, \ - const std::vector& dependencies); \ - void (*compute_backward_buffer_inplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & inout); \ - void (*compute_backward_buffer_inplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & inout); \ - void (*compute_backward_buffer_inplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & inout_re, sycl::buffer & inout_im); \ - void (*compute_backward_buffer_outofplace_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out); \ - void (*compute_backward_buffer_outofplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out); \ - void (*compute_backward_buffer_outofplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in, sycl::buffer & out); \ - void (*compute_backward_buffer_outofplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, \ - sycl::buffer & in_re, sycl::buffer & in_im, \ - sycl::buffer & out_re, sycl::buffer & out_im); \ - sycl::event (*compute_backward_usm_inplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * inout, \ - const std::vector& dependencies); \ - sycl::event (*compute_backward_usm_inplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_BACKWARD * inout, \ - const std::vector& dependencies); \ - sycl::event (*compute_backward_usm_inplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * inout_re, \ - T_REAL * inout_im, const std::vector& dependencies); \ - sycl::event (*compute_backward_usm_outofplace_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_BACKWARD * in, \ - T_FORWARD * out, const std::vector& dependencies); \ - sycl::event (*compute_backward_usm_outofplace_real_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * in, T_REAL * out, \ - const std::vector& dependencies); \ - sycl::event (*compute_backward_usm_outofplace_complex_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_BACKWARD * in, \ - T_BACKWARD * out, const std::vector& dependencies); \ - sycl::event (*compute_backward_usm_outofplace_split_##EXT)( \ - oneapi::mkl::dft::detail::descriptor & desc, T_REAL * in_re, \ - T_REAL * in_im, T_REAL * out_re, T_REAL * out_im, \ - const std::vector& dependencies); - - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(f, oneapi::mkl::dft::detail::precision::SINGLE, - oneapi::mkl::dft::detail::domain::REAL, float, float, - std::complex) - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(c, oneapi::mkl::dft::detail::precision::SINGLE, - oneapi::mkl::dft::detail::domain::COMPLEX, float, - std::complex, std::complex) - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(d, oneapi::mkl::dft::detail::precision::DOUBLE, - oneapi::mkl::dft::detail::domain::REAL, double, double, - std::complex) - ONEAPI_MKL_DFT_BACKEND_SIGNATURES(z, oneapi::mkl::dft::detail::precision::DOUBLE, - oneapi::mkl::dft::detail::domain::COMPLEX, double, - std::complex, std::complex) - -#undef ONEAPI_MKL_DFT_BACKEND_SIGNATURES } dft_function_table_t; #endif //_DFT_FUNCTION_TABLE_HPP_ diff --git a/src/include/dtype_string.hpp b/src/include/dtype_string.hpp new file mode 100644 index 000000000..6f2a87feb --- /dev/null +++ b/src/include/dtype_string.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2020-2021 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_ERROR_HELPER_HPP_ +#define _ONEMKL_ERROR_HELPER_HPP_ + +#include + +template +inline const std::string dtype_string(); +template <> +inline const std::string dtype_string() { + return "float"; +} +template <> +inline const std::string dtype_string() { + return "double"; +} +template <> +inline const std::string dtype_string() { + return "half"; +} +template <> +inline const std::string dtype_string>() { + return "complex"; +} +template <> +inline const std::string dtype_string>() { + return "complex"; +} +template <> +inline const std::string dtype_string() { + return "int32"; +} +template <> +inline const std::string dtype_string() { + return "int8"; +} + +#endif //_ONEMKL_ERROR_HELPER_HPP_ diff --git a/src/lapack/CMakeLists.txt b/src/lapack/CMakeLists.txt index b322d06a2..524edde03 100644 --- a/src/lapack/CMakeLists.txt +++ b/src/lapack/CMakeLists.txt @@ -29,6 +29,7 @@ target_include_directories(onemkl_lapack ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} $ ) diff --git a/src/lapack/backends/CMakeLists.txt b/src/lapack/backends/CMakeLists.txt index 8085ba125..636f6728f 100644 --- a/src/lapack/backends/CMakeLists.txt +++ b/src/lapack/backends/CMakeLists.txt @@ -17,6 +17,9 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== +add_custom_target(onemkl_backend_libs_lapack) +add_dependencies(onemkl_backend_libs onemkl_backend_libs_lapack) + if(ENABLE_MKLCPU_BACKEND) add_subdirectory(mklcpu) endif() diff --git a/src/lapack/backends/cusolver/CMakeLists.txt b/src/lapack/backends/cusolver/CMakeLists.txt index e40119dfe..dfd1267d7 100644 --- a/src/lapack/backends/cusolver/CMakeLists.txt +++ b/src/lapack/backends/cusolver/CMakeLists.txt @@ -20,21 +20,28 @@ set(LIB_NAME onemkl_lapack_cusolver) set(LIB_OBJ ${LIB_NAME}_obj) find_package(cuSOLVER REQUIRED) +find_package(cuBLAS REQUIRED) set(SOURCES cusolver_lapack.cpp cusolver_batch.cpp $<$:cusolver_scope_handle.cpp > $<$: cusolver_wrappers.cpp>) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_lapack ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/src + ${PROJECT_SOURCE_DIR}/src/blas/backends/cublas ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} ) target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::cuSOLVER::cuSOLVER) +target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + ONEMKL::cuSOLVER::cuSOLVER + ONEMKL::cuBLAS::cuBLAS) target_compile_features(${LIB_OBJ} PUBLIC cxx_std_11) set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/src/lapack/backends/cusolver/cusolver_batch.cpp b/src/lapack/backends/cusolver/cusolver_batch.cpp index 57b9f4a88..59fa47f84 100644 --- a/src/lapack/backends/cusolver/cusolver_batch.cpp +++ b/src/lapack/backends/cusolver/cusolver_batch.cpp @@ -16,6 +16,7 @@ * limitations under the License. * **************************************************************************/ +#include "cublas_helper.hpp" #include "cusolver_helper.hpp" #include "cusolver_task.hpp" @@ -76,31 +77,112 @@ GEQRF_STRIDED_BATCH_LAUNCHER(std::complex, cusolverDnZgeqrf) #undef GEQRF_STRIDED_BATCH_LAUNCHER -void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, - std::int64_t stride_a, sycl::buffer &ipiv, std::int64_t stride_ipiv, - std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri_batch"); -} -void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, - std::int64_t stride_a, sycl::buffer &ipiv, std::int64_t stride_ipiv, - std::int64_t batch_size, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri_batch"); -} -void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer> &a, - std::int64_t lda, std::int64_t stride_a, sycl::buffer &ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri_batch"); -} -void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer> &a, - std::int64_t lda, std::int64_t stride_a, sycl::buffer &ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri_batch"); +template +inline void getri_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t n, + sycl::buffer &a, std::int64_t lda, std::int64_t stride_a, + sycl::buffer &ipiv, std::int64_t stride_ipiv, + std::int64_t batch_size, sycl::buffer &scratchpad, + std::int64_t scratchpad_size) { + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, lda, stride_a, stride_ipiv, batch_size, scratchpad_size); + + std::uint64_t ipiv32_size = n * batch_size; + sycl::buffer ipiv32(sycl::range<1>{ ipiv32_size }); + sycl::buffer devInfo{ batch_size }; + + queue.submit([&](sycl::handler &cgh) { + auto ipiv_acc = sycl::accessor{ ipiv, cgh, sycl::read_only }; + auto ipiv32_acc = sycl::accessor{ ipiv32, cgh, sycl::write_only }; + cgh.parallel_for(sycl::range<1>{ ipiv32_size }, [=](sycl::id<1> index) { + ipiv32_acc[index] = static_cast(ipiv_acc[(index / n) * stride_ipiv + index % n]); + }); + }); + + // getri_batched is contained within cublas, not cusolver. For this reason + // we need to use cublas types instead of cusolver types (as is needed for + // other lapack routines) + queue.submit([&](sycl::handler &cgh) { + using blas::cublas::cublas_error; + + sycl::accessor a_acc{ a, cgh, sycl::read_only }; + sycl::accessor scratch_acc{ scratchpad, cgh, sycl::write_only }; + sycl::accessor ipiv32_acc{ ipiv32, cgh }; + sycl::accessor devInfo_acc{ devInfo, cgh, sycl::write_only }; + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + cublasStatus_t err; + CUresult cuda_result; + cublasHandle_t cublas_handle; + CUBLAS_ERROR_FUNC(cublasCreate, err, &cublas_handle); + CUstream cu_stream = sycl::get_native(queue); + CUBLAS_ERROR_FUNC(cublasSetStream, err, cublas_handle, cu_stream); + + auto a_ = sc.get_mem(a_acc); + auto scratch_ = sc.get_mem(scratch_acc); + auto ipiv32_ = sc.get_mem(ipiv32_acc); + auto info_ = sc.get_mem(devInfo_acc); + + CUdeviceptr a_dev; + cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size); + auto **a_dev_ = reinterpret_cast(a_dev); + + CUdeviceptr scratch_dev; + cuDataType **scratch_batched = + create_ptr_list_from_stride(scratch_, stride_a, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &scratch_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, scratch_dev, scratch_batched, + sizeof(T *) * batch_size); + auto **scratch_dev_ = reinterpret_cast(scratch_dev); + + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32_, + scratch_dev_, lda, info_, batch_size) + + free(a_batched); + free(scratch_batched); + cuMemFree(a_dev); + cuMemFree(scratch_dev); + }); + }); + + // The inverted matrices stored in scratch_ need to be stored in a_ + queue.submit([&](sycl::handler &cgh) { + sycl::accessor a_acc{ a, cgh, sycl::write_only }; + sycl::accessor scratch_acc{ scratchpad, cgh, sycl::read_only }; + cgh.parallel_for(sycl::range<1>{ static_cast( + sycl::max(stride_a * batch_size, lda * n * batch_size)) }, + [=](sycl::id<1> index) { a_acc[index] = scratch_acc[index]; }); + }); + + queue.submit([&](sycl::handler &cgh) { + sycl::accessor ipiv32_acc{ ipiv32, cgh, sycl::read_only }; + sycl::accessor ipiv_acc{ ipiv, cgh, sycl::write_only }; + cgh.parallel_for(sycl::range<1>{ static_cast(ipiv32_size) }, + [=](sycl::id<1> index) { + ipiv_acc[(index / n) * stride_ipiv + index % n] = + static_cast(ipiv32_acc[index]); + }); + }); } +#define GETRI_STRIDED_BATCH_LAUNCHER(TYPE, CUSOLVER_ROUTINE) \ + void getri_batch(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, \ + std::int64_t stride_a, sycl::buffer &ipiv, \ + std::int64_t stride_ipiv, std::int64_t batch_size, \ + sycl::buffer &scratchpad, std::int64_t scratchpad_size) { \ + return getri_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, n, a, lda, stride_a, ipiv, \ + stride_ipiv, batch_size, scratchpad, scratchpad_size); \ + } + +GETRI_STRIDED_BATCH_LAUNCHER(float, cublasSgetriBatched) +GETRI_STRIDED_BATCH_LAUNCHER(double, cublasDgetriBatched) +GETRI_STRIDED_BATCH_LAUNCHER(std::complex, cublasCgetriBatched) +GETRI_STRIDED_BATCH_LAUNCHER(std::complex, cublasZgetriBatched) + +#undef GETRI_STRIDED_BATCH_LAUNCHER + template inline void getrs_batch(const char *func_name, Func func, sycl::queue &queue, oneapi::mkl::transpose trans, std::int64_t n, std::int64_t nrhs, @@ -459,10 +541,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(m, n, lda, stride_a, stride_tau, batch_size, scratchpad_size); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -513,10 +592,7 @@ inline sycl::event geqrf_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(m[i], n[i], lda[i], group_sizes[i]); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -574,10 +650,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -659,10 +732,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu int *devInfo = (int *)malloc_device(sizeof(int) * batch_size, queue); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -701,10 +771,7 @@ inline sycl::event getrf_batch(const char *func_name, Func func, sycl::queue &qu // Enqueue free memory sycl::event done_freeing = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = casting_dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(casting_dependencies[i]); - } + cgh.depends_on(casting_dependencies); cgh.host_task([=](sycl::interop_handle ih) { for (int64_t global_id = 0; global_id < batch_size; ++global_id) sycl::free(ipiv32[global_id], queue); @@ -736,32 +803,108 @@ GETRF_BATCH_LAUNCHER_USM(std::complex, cusolverDnZgetrf) #undef GETRS_BATCH_LAUNCHER_USM -sycl::event getri_batch(sycl::queue &queue, std::int64_t n, float *a, std::int64_t lda, - std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv, - std::int64_t batch_size, float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri_batch"); -} -sycl::event getri_batch(sycl::queue &queue, std::int64_t n, double *a, std::int64_t lda, - std::int64_t stride_a, std::int64_t *ipiv, std::int64_t stride_ipiv, - std::int64_t batch_size, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri_batch"); -} -sycl::event getri_batch(sycl::queue &queue, std::int64_t n, std::complex *a, - std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri_batch"); -} -sycl::event getri_batch(sycl::queue &queue, std::int64_t n, std::complex *a, +template +sycl::event getri_batch(const char *func_name, Func func, sycl::queue &queue, std::int64_t n, T *a, std::int64_t lda, std::int64_t stride_a, std::int64_t *ipiv, - std::int64_t stride_ipiv, std::int64_t batch_size, - std::complex *scratchpad, std::int64_t scratchpad_size, + std::int64_t stride_ipiv, std::int64_t batch_size, T *scratchpad, + std::int64_t scratchpad_size, const std::vector &dependencies) { - throw unimplemented("lapack", "getri_batch"); + using cuDataType = typename CudaEquivalentType::Type; + + overflow_check(n, lda, stride_a, stride_ipiv, batch_size, scratchpad_size); + + std::uint64_t ipiv32_size = n * batch_size; + int *ipiv32 = sycl::malloc_device(ipiv32_size, queue); + int *devInfo = sycl::malloc_device(batch_size, queue); + + sycl::event done_casting = queue.submit([&](sycl::handler &cgh) { + cgh.parallel_for( + sycl::range<1>{ static_cast(ipiv32_size) }, [=](sycl::id<1> index) { + ipiv32[index] = static_cast(ipiv[(index / n) * stride_ipiv + index % n]); + }); + }); + + // getri_batched is contained within cublas, not cusolver. For this reason + // we need to use cublas types instead of cusolver types (as is needed for + // other lapack routines) + auto done = queue.submit([&](sycl::handler &cgh) { + using blas::cublas::cublas_error; + + cgh.depends_on(done_casting); + cgh.depends_on(dependencies); + + onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { + cublasStatus_t err; + CUresult cuda_result; + cublasHandle_t cublas_handle; + CUBLAS_ERROR_FUNC(cublasCreate, err, &cublas_handle); + CUstream cu_stream = sycl::get_native(queue); + CUBLAS_ERROR_FUNC(cublasSetStream, err, cublas_handle, cu_stream); + + CUdeviceptr a_dev; + auto *a_ = reinterpret_cast(a); + cuDataType **a_batched = create_ptr_list_from_stride(a_, stride_a, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &a_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, a_dev, a_batched, sizeof(T *) * batch_size); + auto **a_dev_ = reinterpret_cast(a_dev); + + CUdeviceptr scratch_dev; + auto *scratch_ = reinterpret_cast(scratchpad); + cuDataType **scratch_batched = + create_ptr_list_from_stride(scratch_, stride_a, batch_size); + CUDA_ERROR_FUNC(cuMemAlloc, cuda_result, &scratch_dev, sizeof(T *) * batch_size); + CUDA_ERROR_FUNC(cuMemcpyHtoD, cuda_result, scratch_dev, scratch_batched, + sizeof(T *) * batch_size); + auto **scratch_dev_ = reinterpret_cast(scratch_dev); + + CUBLAS_ERROR_FUNC_T_SYNC(func_name, func, err, cublas_handle, n, a_dev_, lda, ipiv32, + scratch_dev_, lda, devInfo, batch_size) + + free(a_batched); + free(scratch_batched); + cuMemFree(a_dev); + cuMemFree(scratch_dev); + }); + }); + + // The inverted matrices stored in scratch_ need to be stored in a_ + auto copy1 = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + cgh.parallel_for( + sycl::range<1>{ static_cast(stride_a * (batch_size - 1) + lda * n) }, + [=](sycl::id<1> index) { a[index] = scratchpad[index]; }); + }); + + auto copy2 = queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); + cgh.parallel_for( + sycl::range<1>{ static_cast(ipiv32_size) }, [=](sycl::id<1> index) { + ipiv[(index / n) * stride_ipiv + index % n] = static_cast(ipiv32[index]); + }); + }); + copy1.wait(); + copy2.wait(); + sycl::free(ipiv32, queue); + sycl::free(devInfo, queue); + return done; } + +#define GETRI_BATCH_LAUNCHER_USM(TYPE, CUSOLVER_ROUTINE) \ + sycl::event getri_batch( \ + sycl::queue &queue, std::int64_t n, TYPE *a, std::int64_t lda, std::int64_t stride_a, \ + std::int64_t *ipiv, std::int64_t stride_ipiv, std::int64_t batch_size, TYPE *scratchpad, \ + std::int64_t scratchpad_size, const std::vector &dependencies) { \ + return getri_batch(#CUSOLVER_ROUTINE, CUSOLVER_ROUTINE, queue, n, a, lda, stride_a, ipiv, \ + stride_ipiv, batch_size, scratchpad, scratchpad_size, dependencies); \ + } + +GETRI_BATCH_LAUNCHER_USM(float, cublasSgetriBatched) +GETRI_BATCH_LAUNCHER_USM(double, cublasDgetriBatched) +GETRI_BATCH_LAUNCHER_USM(std::complex, cublasCgetriBatched) +GETRI_BATCH_LAUNCHER_USM(std::complex, cublasZgetriBatched) + +#undef GETRI_BATCH_LAUNCHER_USM + sycl::event getri_batch(sycl::queue &queue, std::int64_t *n, float **a, std::int64_t *lda, std::int64_t **ipiv, std::int64_t group_count, std::int64_t *group_sizes, float *scratchpad, std::int64_t scratchpad_size, @@ -814,10 +957,7 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu }); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); cgh.depends_on(done_casting); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -902,13 +1042,8 @@ inline sycl::event getrs_batch(const char *func_name, Func func, sycl::queue &qu } auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } - for (int64_t i = 0; i < batch_size; i++) { - cgh.depends_on(casting_dependencies[i]); - } + cgh.depends_on(dependencies); + cgh.depends_on(casting_dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); @@ -967,10 +1102,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -1020,10 +1152,7 @@ inline sycl::event orgqr_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -1074,10 +1203,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(n, lda, stride_a, batch_size, scratchpad_size); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); CUdeviceptr a_dev; @@ -1135,10 +1261,7 @@ inline sycl::event potrf_batch(const char *func_name, Func func, sycl::queue &qu } auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); int64_t offset = 0; @@ -1199,10 +1322,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu throw unimplemented("lapack", "potrs_batch", "cusolver potrs_batch only supports nrhs = 1"); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); CUresult cuda_result; @@ -1283,10 +1403,7 @@ inline sycl::event potrs_batch(const char *func_name, Func func, sycl::queue &qu queue.submit([&](sycl::handler &h) { h.memcpy(b_dev, b, batch_size * sizeof(T *)); }); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); cgh.depends_on(done_cpy_a); cgh.depends_on(done_cpy_b); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { @@ -1340,10 +1457,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(m, n, k, lda, stride_a, stride_tau, batch_size, scratchpad_size); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -1393,10 +1507,7 @@ inline sycl::event ungqr_batch(const char *func_name, Func func, sycl::queue &qu overflow_check(m[i], n[i], k[i], lda[i], group_sizes[i]); auto done = queue.submit([&](sycl::handler &cgh) { - int64_t num_events = dependencies.size(); - for (int64_t i = 0; i < num_events; i++) { - cgh.depends_on(dependencies[i]); - } + cgh.depends_on(dependencies); onemkl_cusolver_host_task(cgh, queue, [=](CusolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = reinterpret_cast(a); @@ -1472,35 +1583,22 @@ GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_buff #undef GETRF_STRIDED_BATCH_LAUNCHER_SCRATCH -template <> -std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t n, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_ipiv, std::int64_t batch_size) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t n, - std::int64_t lda, std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda, - std::int64_t stride_a, - std::int64_t stride_ipiv, - std::int64_t batch_size) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} +// Scratch memory needs to be the same size as a +#define GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t getri_batch_scratchpad_size( \ + sycl::queue & queue, std::int64_t n, std::int64_t lda, std::int64_t stride_a, \ + std::int64_t stride_ipiv, std::int64_t batch_size) { \ + assert(stride_a >= lda * n && "A matrices must not overlap"); \ + return stride_a * (batch_size - 1) + lda * n; \ + } + +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(float) +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(double) +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) +GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH(std::complex) + +#undef GETRI_STRIDED_BATCH_LAUNCHER_SCRATCH // cusolverDnXgetrs does not use scratchpad memory #define GETRS_STRIDED_BATCH_LAUNCHER_SCRATCH(TYPE) \ @@ -1696,32 +1794,26 @@ GETRF_GROUP_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_bufferSize) #undef GETRF_GROUP_LAUNCHER_SCRATCH -template <> -std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} -template <> -std::int64_t getri_batch_scratchpad_size>(sycl::queue &queue, std::int64_t *n, - std::int64_t *lda, - std::int64_t group_count, - std::int64_t *group_sizes) { - throw unimplemented("lapack", "getri_batch_scratchpad_size"); -} +#define GETRI_GROUP_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t getri_batch_scratchpad_size(sycl::queue & queue, std::int64_t * n, \ + std::int64_t * lda, std::int64_t group_count, \ + std::int64_t * group_sizes) { \ + std::int64_t max_scratch_sz = 0; \ + for (auto group_id = 0; group_id < group_count; ++group_id) { \ + auto scratch_sz = lda[group_id] * n[group_id]; \ + if (scratch_sz > max_scratch_sz) \ + max_scratch_sz = scratch_sz; \ + } \ + return max_scratch_sz; \ + } + +GETRI_GROUP_LAUNCHER_SCRATCH(float) +GETRI_GROUP_LAUNCHER_SCRATCH(double) +GETRI_GROUP_LAUNCHER_SCRATCH(std::complex) +GETRI_GROUP_LAUNCHER_SCRATCH(std::complex) + +#undef GETRI_GROUP_LAUNCHER_SCRATCH #define GETRS_GROUP_LAUNCHER_SCRATCH(TYPE) \ template <> \ diff --git a/src/lapack/backends/cusolver/cusolver_lapack.cpp b/src/lapack/backends/cusolver/cusolver_lapack.cpp index 4fbdccc72..0c7aaefc8 100644 --- a/src/lapack/backends/cusolver/cusolver_lapack.cpp +++ b/src/lapack/backends/cusolver/cusolver_lapack.cpp @@ -195,26 +195,19 @@ GETRF_LAUNCHER(std::complex, cusolverDnZgetrf) #undef GETRF_LAUNCHER -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer> &a, - std::int64_t lda, sycl::buffer &ipiv, - sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, - sycl::buffer &ipiv, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, - sycl::buffer &ipiv, sycl::buffer &scratchpad, - std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} -void getri(sycl::queue &queue, std::int64_t n, sycl::buffer> &a, - std::int64_t lda, sycl::buffer &ipiv, - sycl::buffer> &scratchpad, std::int64_t scratchpad_size) { - throw unimplemented("lapack", "getri"); -} +#define GETRI_LAUNCHER(TYPE) \ + void getri(sycl::queue &queue, std::int64_t n, sycl::buffer &a, std::int64_t lda, \ + sycl::buffer &ipiv, sycl::buffer &scratchpad, \ + std::int64_t scratchpad_size) { \ + return getri_batch(queue, n, a, lda, lda * n, ipiv, n, 1, scratchpad, scratchpad_size); \ + } + +GETRI_LAUNCHER(float) +GETRI_LAUNCHER(double) +GETRI_LAUNCHER(std::complex) +GETRI_LAUNCHER(std::complex) + +#undef GETRI_LAUNCHER // cusolverDnXgetrs does not use scratchpad memory template @@ -1380,26 +1373,20 @@ GETRF_LAUNCHER_USM(std::complex, cusolverDnZgetrf) #undef GETRF_LAUNCHER_USM -sycl::event getri(sycl::queue &queue, std::int64_t n, std::complex *a, std::int64_t lda, - std::int64_t *ipiv, std::complex *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} -sycl::event getri(sycl::queue &queue, std::int64_t n, double *a, std::int64_t lda, - std::int64_t *ipiv, double *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} -sycl::event getri(sycl::queue &queue, std::int64_t n, float *a, std::int64_t lda, - std::int64_t *ipiv, float *scratchpad, std::int64_t scratchpad_size, - const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} -sycl::event getri(sycl::queue &queue, std::int64_t n, std::complex *a, std::int64_t lda, - std::int64_t *ipiv, std::complex *scratchpad, - std::int64_t scratchpad_size, const std::vector &dependencies) { - throw unimplemented("lapack", "getri"); -} +#define GETRI_LAUNCHER_USM(TYPE) \ + sycl::event getri(sycl::queue &queue, std::int64_t n, TYPE *a, std::int64_t lda, \ + std::int64_t *ipiv, TYPE *scratchpad, std::int64_t scratchpad_size, \ + const std::vector &dependencies) { \ + return getri_batch(queue, n, a, lda, lda * n, ipiv, n, 1, scratchpad, scratchpad_size, \ + dependencies); \ + } + +GETRI_LAUNCHER_USM(float) +GETRI_LAUNCHER_USM(double) +GETRI_LAUNCHER_USM(std::complex) +GETRI_LAUNCHER_USM(std::complex) + +#undef GETRI_LAUNCHER_USM // cusolverDnXgetrs does not use scratchpad memory template @@ -2471,6 +2458,7 @@ inline void gebrd_scratchpad_size(const char *func_name, Func func, sycl::queue CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size); }); }); + queue.wait(); } #define GEBRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2522,6 +2510,7 @@ inline void geqrf_scratchpad_size(const char *func_name, Func func, sycl::queue CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); }); + queue.wait(); } #define GEQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2553,6 +2542,7 @@ inline void gesvd_scratchpad_size(const char *func_name, Func func, sycl::queue CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, scratch_size); }); }); + queue.wait(); } #define GESVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2584,6 +2574,7 @@ inline void getrf_scratchpad_size(const char *func_name, Func func, sycl::queue CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, nullptr, lda, scratch_size); }); }); + queue.wait(); } #define GETRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2603,24 +2594,19 @@ GETRF_LAUNCHER_SCRATCH(std::complex, cusolverDnZgetrf_bufferSize) #undef GETRF_LAUNCHER_SCRATCH -template <> -std::int64_t getri_scratchpad_size(sycl::queue &queue, std::int64_t n, std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} -template <> -std::int64_t getri_scratchpad_size(sycl::queue &queue, std::int64_t n, std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} -template <> -std::int64_t getri_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} -template <> -std::int64_t getri_scratchpad_size>(sycl::queue &queue, std::int64_t n, - std::int64_t lda) { - throw unimplemented("lapack", "getri_scratchpad_size"); -} +#define GETRI_LAUNCHER_SCRATCH(TYPE) \ + template <> \ + std::int64_t getri_scratchpad_size(sycl::queue & queue, std::int64_t n, \ + std::int64_t lda) { \ + return lda * n; \ + } + +GETRI_LAUNCHER_SCRATCH(float) +GETRI_LAUNCHER_SCRATCH(double) +GETRI_LAUNCHER_SCRATCH(std::complex) +GETRI_LAUNCHER_SCRATCH(std::complex) + +#undef GETRI_LAUNCHER_SCRATCH // cusolverDnXgetrs does not use scratchpad memory #define GETRS_LAUNCHER_SCRATCH(TYPE) \ @@ -2651,6 +2637,7 @@ inline void heevd_scratchpad_size(const char *func_name, Func func, sycl::queue scratch_size); }); }); + queue.wait(); } #define HEEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2683,6 +2670,7 @@ inline void hegvd_scratchpad_size(const char *func_name, Func func, sycl::queue lda, nullptr, ldb, nullptr, scratch_size); }); }); + queue.wait(); } #define HEGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2713,6 +2701,7 @@ inline void hetrd_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, nullptr, nullptr, scratch_size); }); }); + queue.wait(); } #define HETRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2753,6 +2742,7 @@ inline void orgbr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, scratch_size); }); }); + queue.wait(); } #define ORGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2783,6 +2773,7 @@ inline void orgtr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, scratch_size); }); }); + queue.wait(); } #define ORGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2812,6 +2803,7 @@ inline void orgqr_scratchpad_size(const char *func_name, Func func, sycl::queue scratch_size); }); }); + queue.wait(); } #define ORGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2858,6 +2850,7 @@ inline void ormqr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, ldc, scratch_size); }); }); + queue.wait(); } #define ORMQRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2890,6 +2883,7 @@ inline void ormtr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); }); + queue.wait(); } #define ORMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2921,6 +2915,7 @@ inline void potrf_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, scratch_size); }); }); + queue.wait(); } #define POTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2968,6 +2963,7 @@ inline void potri_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, scratch_size); }); }); + queue.wait(); } #define POTRI_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -2998,6 +2994,7 @@ inline void sytrf_scratchpad_size(const char *func_name, Func func, sycl::queue CUSOLVER_ERROR_FUNC_T(func_name, func, err, handle, n, nullptr, lda, scratch_size); }); }); + queue.wait(); } #define SYTRF_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3030,6 +3027,7 @@ inline void syevd_scratchpad_size(const char *func_name, Func func, sycl::queue scratch_size); }); }); + queue.wait(); } #define SYEVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3062,6 +3060,7 @@ inline void sygvd_scratchpad_size(const char *func_name, Func func, sycl::queue lda, nullptr, ldb, nullptr, scratch_size); }); }); + queue.wait(); } #define SYGVD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3092,6 +3091,7 @@ inline void sytrd_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, nullptr, nullptr, scratch_size); }); }); + queue.wait(); } #define SYTRD_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3152,6 +3152,7 @@ inline void ungbr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, scratch_size); }); }); + queue.wait(); } #define UNGBR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3182,6 +3183,7 @@ inline void ungqr_scratchpad_size(const char *func_name, Func func, sycl::queue scratch_size); }); }); + queue.wait(); } #define UNGQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3211,6 +3213,7 @@ inline void ungtr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, scratch_size); }); }); + queue.wait(); } #define UNGTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3259,6 +3262,7 @@ inline void unmqr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, ldc, scratch_size); }); }); + queue.wait(); } #define UNMQR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ @@ -3291,6 +3295,7 @@ inline void unmtr_scratchpad_size(const char *func_name, Func func, sycl::queue nullptr, lda, nullptr, nullptr, ldc, scratch_size); }); }); + queue.wait(); } #define UNMTR_LAUNCHER_SCRATCH(TYPE, CUSOLVER_ROUTINE) \ diff --git a/src/lapack/backends/cusolver/cusolver_scope_handle.cpp b/src/lapack/backends/cusolver/cusolver_scope_handle.cpp index b2881cdc0..0bc3ebdb0 100644 --- a/src/lapack/backends/cusolver/cusolver_scope_handle.cpp +++ b/src/lapack/backends/cusolver/cusolver_scope_handle.cpp @@ -39,14 +39,15 @@ thread_local cusolver_handle CusolverScopedContextHandler::handle_he cusolver_handle{}; CusolverScopedContextHandler::CusolverScopedContextHandler(sycl::queue queue, - sycl::interop_handler &ih) + sycl::interop_handle &ih) : ih(ih), needToRecover_(false) { - placedContext_ = queue.get_context(); - auto device = queue.get_device(); - auto desired = sycl::get_native(placedContext_); + placedContext_ = new sycl::context(queue.get_context()); + auto cudaDevice = ih.get_native_device(); CUresult err; + CUcontext desired; CUDA_ERROR_FUNC(cuCtxGetCurrent, err, &original_); + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, err, &desired, cudaDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread CUDA_ERROR_FUNC(cuCtxSetCurrent, err, desired); @@ -65,6 +66,7 @@ CusolverScopedContextHandler::~CusolverScopedContextHandler() noexcept(false) { CUresult err; CUDA_ERROR_FUNC(cuCtxSetCurrent, err, original_); } + delete placedContext_; } void ContextCallback(void *userData) { @@ -87,8 +89,11 @@ void ContextCallback(void *userData) { } cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = - reinterpret_cast(sycl::get_native(placedContext_)); + auto cudaDevice = ih.get_native_device(); + CUresult cuErr; + CUcontext desired; + CUDA_ERROR_FUNC(cuDevicePrimaryCtxRetain, cuErr, &desired, cudaDevice); + auto piPlacedContext_ = reinterpret_cast(desired); CUstream streamId = get_stream(queue); cusolverStatus_t err; auto it = handle_helper.cusolver_handle_mapper_.find(piPlacedContext_); @@ -120,14 +125,14 @@ cusolverDnHandle_t CusolverScopedContextHandler::get_handle(const sycl::queue &q auto insert_iter = handle_helper.cusolver_handle_mapper_.insert( std::make_pair(piPlacedContext_, new std::atomic(handle))); - sycl::detail::pi::contextSetExtendedDeleter(placedContext_, ContextCallback, + sycl::detail::pi::contextSetExtendedDeleter(*placedContext_, ContextCallback, insert_iter.first->second); return handle; } CUstream CusolverScopedContextHandler::get_stream(const sycl::queue &queue) { - return sycl::get_native(queue); + return sycl::get_native(queue); } sycl::context CusolverScopedContextHandler::get_context(const sycl::queue &queue) { return queue.get_context(); diff --git a/src/lapack/backends/cusolver/cusolver_scope_handle.hpp b/src/lapack/backends/cusolver/cusolver_scope_handle.hpp index f26a9c449..585b4995a 100644 --- a/src/lapack/backends/cusolver/cusolver_scope_handle.hpp +++ b/src/lapack/backends/cusolver/cusolver_scope_handle.hpp @@ -23,8 +23,10 @@ #else #include #endif -#if __has_include() +#if __has_include() +#if __SYCL_COMPILER_VERSION <= 20220930 #include +#endif #include #include #else @@ -77,15 +79,15 @@ cuSolver handle to the SYCL context. class CusolverScopedContextHandler { CUcontext original_; - sycl::context placedContext_; + sycl::context *placedContext_; bool needToRecover_; - sycl::interop_handler &ih; + sycl::interop_handle &ih; static thread_local cusolver_handle handle_helper; CUstream get_stream(const sycl::queue &queue); sycl::context get_context(const sycl::queue &queue); public: - CusolverScopedContextHandler(sycl::queue queue, sycl::interop_handler &ih); + CusolverScopedContextHandler(sycl::queue queue, sycl::interop_handle &ih); ~CusolverScopedContextHandler() noexcept(false); /** @@ -100,9 +102,13 @@ class CusolverScopedContextHandler { // will be fixed when SYCL-2020 has been implemented for Pi backend. template inline T get_mem(U acc) { - CUdeviceptr cudaPtr = ih.get_mem(acc); + CUdeviceptr cudaPtr = ih.get_native_mem(acc); return reinterpret_cast(cudaPtr); } + + void wait_stream(const sycl::queue &queue) { + cuStreamSynchronize(get_stream(queue)); + } }; } // namespace cusolver diff --git a/src/lapack/backends/cusolver/cusolver_task.hpp b/src/lapack/backends/cusolver/cusolver_task.hpp index 45eb23bbf..9d319be64 100644 --- a/src/lapack/backends/cusolver/cusolver_task.hpp +++ b/src/lapack/backends/cusolver/cusolver_task.hpp @@ -42,9 +42,10 @@ namespace cusolver { template static inline void host_task_internal(H &cgh, sycl::queue queue, F f) { - cgh.interop_task([f, queue](sycl::interop_handler ih) { + cgh.host_task([f, queue](sycl::interop_handle ih) { auto sc = CusolverScopedContextHandler(queue, ih); f(sc); + sc.wait_stream(queue); }); } diff --git a/src/lapack/backends/mklcpu/CMakeLists.txt b/src/lapack/backends/mklcpu/CMakeLists.txt index e26eb0112..fcc60a8e7 100644 --- a/src/lapack/backends/mklcpu/CMakeLists.txt +++ b/src/lapack/backends/mklcpu/CMakeLists.txt @@ -20,25 +20,27 @@ set(LIB_NAME onemkl_lapack_mklcpu) set(LIB_OBJ ${LIB_NAME}_obj) -set(USE_DPCPP_API ON) -find_package(MKL REQUIRED) - add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT mkl_lapack.cpp $<$: lapack_cpu_wrappers.cpp> ) +add_dependencies(onemkl_backend_libs_lapack ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) +if(TARGET MKL::MKL_SYCL::LAPACK) + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_SYCL::LAPACK) +else() + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_DPCPP) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/lapack/backends/mklgpu/CMakeLists.txt b/src/lapack/backends/mklgpu/CMakeLists.txt index 0fac4664a..e11592f82 100644 --- a/src/lapack/backends/mklgpu/CMakeLists.txt +++ b/src/lapack/backends/mklgpu/CMakeLists.txt @@ -20,24 +20,27 @@ set(LIB_NAME onemkl_lapack_mklgpu) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(MKL REQUIRED) - add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT mkl_lapack.cpp $<$: lapack_gpu_wrappers.cpp> ) +add_dependencies(onemkl_backend_libs_lapack ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) +if(TARGET MKL::MKL_SYCL::LAPACK) + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_SYCL::LAPACK) +else() + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_DPCPP) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/lapack/backends/rocsolver/CMakeLists.txt b/src/lapack/backends/rocsolver/CMakeLists.txt index 78841510e..c91089118 100644 --- a/src/lapack/backends/rocsolver/CMakeLists.txt +++ b/src/lapack/backends/rocsolver/CMakeLists.txt @@ -21,22 +21,28 @@ set(LIB_NAME onemkl_lapack_rocsolver) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(rocSOLVER REQUIRED) +find_package(hip REQUIRED) +find_package(rocsolver REQUIRED) +find_package(Threads REQUIRED) + set(SOURCES rocsolver_lapack.cpp rocsolver_batch.cpp $<$:rocsolver_scope_handle.cpp> $<$: rocsolver_wrappers.cpp>) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_lapack ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src/include ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} ) target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::rocSOLVER::rocSOLVER) +target_link_libraries(${LIB_OBJ} PRIVATE roc::rocsolver hip::host Threads::Threads) +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL) target_compile_features(${LIB_OBJ} PUBLIC cxx_std_17) set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/src/lapack/backends/rocsolver/rocsolver_helper.hpp b/src/lapack/backends/rocsolver/rocsolver_helper.hpp index b1beff3ca..dade1df64 100644 --- a/src/lapack/backends/rocsolver/rocsolver_helper.hpp +++ b/src/lapack/backends/rocsolver/rocsolver_helper.hpp @@ -27,8 +27,8 @@ #define _ROCSOLVER_HELPER_HPP_ #include -#include -#include +#include +#include #include #include @@ -82,15 +82,7 @@ void overflow_check(Index index, Next... indices) { class rocsolver_error : virtual public std::runtime_error { protected: inline const char *rocsolver_error_map(rocblas_status error) { - switch (error) { - case rocblas_status_success: return "ROCBLAS_STATUS_SUCCESS"; - - case rocblas_status_invalid_value: return "ROCBLAS_STATUS_INVALID_VALUE"; - - case rocblas_status_internal_error: return "ROCBLAS_STATUS_INTERNAL_ERROR"; - - default: return ""; - } + return rocblas_status_to_string(error); } int error_number; ///< Error number @@ -120,16 +112,7 @@ class rocsolver_error : virtual public std::runtime_error { class hip_error : virtual public std::runtime_error { protected: inline const char *hip_error_map(hipError_t result) { - switch (result) { - case HIP_SUCCESS: return "HIP_SUCCESS"; - case hipErrorNotInitialized: return "hipErrorNotInitialized"; - case hipErrorInvalidContext: return "hipErrorInvalidContext"; - case hipErrorInvalidDevice: return "hipErrorInvalidDevice"; - case hipErrorInvalidValue: return "hipErrorInvalidValue"; - case hipErrorMemoryAllocation: return "hipErrorMemoryAllocation"; - case hipErrorLaunchOutOfResources: return "hipErrorLaunchOutOfResources"; - default: return ""; - } + return hipGetErrorName(result); } int error_number; ///< error number public: @@ -271,14 +254,15 @@ inline int get_rocsolver_devinfo(sycl::queue &queue, sycl::buffer &devInfo) inline int get_rocsolver_devinfo(sycl::queue &queue, const int *devInfo) { int dev_info_; - queue.wait(); queue.memcpy(&dev_info_, devInfo, sizeof(int)); + queue.wait(); return dev_info_; } template inline void lapack_info_check(sycl::queue &queue, DEVINFO_T devinfo, const char *func_name, const char *cufunc_name) { + queue.wait(); const int devinfo_ = get_rocsolver_devinfo(queue, devinfo); if (devinfo_ > 0) throw oneapi::mkl::lapack::computation_error( diff --git a/src/lapack/backends/rocsolver/rocsolver_lapack.cpp b/src/lapack/backends/rocsolver/rocsolver_lapack.cpp index 24708ebcf..e5e634ad0 100644 --- a/src/lapack/backends/rocsolver/rocsolver_lapack.cpp +++ b/src/lapack/backends/rocsolver/rocsolver_lapack.cpp @@ -54,8 +54,8 @@ inline void gebrd(const char *func_name, Func func, sycl::queue &queue, std::int auto tauq_ = sc.get_mem(tauq_acc); auto taup_ = sc.get_mem(taup_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, - taup_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + taup_); }); }); } @@ -112,7 +112,7 @@ inline void geqrf(const char *func_name, Func func, sycl::queue &queue, std::int auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_); }); }); } @@ -146,7 +146,7 @@ void getrf(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, sycl::buffer ipiv32(sycl::range<1>{ ipiv_size }); sycl::buffer devInfo{ 1 }; - queue.submit([&](sycl::handler &cgh) { + auto done = queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto ipiv32_acc = ipiv32.template get_access(cgh); auto devInfo_acc = devInfo.template get_access(cgh); @@ -156,12 +156,14 @@ void getrf(const char *func_name, Func func, sycl::queue &queue, std::int64_t m, auto ipiv32_ = sc.get_mem(ipiv32_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, ipiv32_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, ipiv32_, + devInfo_); }); }); // Copy from 32-bit buffer to 64-bit queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); auto ipiv32_acc = ipiv32.template get_access(cgh); auto ipiv_acc = ipiv.template get_access(cgh); cgh.parallel_for(sycl::range<1>{ ipiv_size }, [=](sycl::id<1> index) { @@ -240,8 +242,8 @@ inline void getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = sc.get_mem(ipiv_acc); auto b_ = sc.get_mem(b_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_operation(trans), n, - nrhs, a_, lda, ipiv_, b_, ldb); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + n, nrhs, a_, lda, ipiv_, b_, ldb); }); }); } @@ -288,9 +290,10 @@ inline void gesvd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), - get_rocsolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, - ldvt, scratch_, rocblas_workmode::rocblas_outofplace, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), + get_rocsolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, + vt_, ldvt, scratch_, rocblas_workmode::rocblas_outofplace, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -334,8 +337,9 @@ inline void heevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_job(jobz), - get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -378,9 +382,9 @@ inline void hegvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_itype(itype), - get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, + lda, b_, ldb, w_, scratch_, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -408,26 +412,22 @@ inline void hetrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: using rocmDataType_A = typename RocmEquivalentType::Type; using rocmDataType_B = typename RocmEquivalentType::Type; overflow_check(n, lda, scratchpad_size); - sycl::buffer devInfo{ 1 }; queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto d_acc = d.template get_access(cgh); auto e_acc = e.template get_access(cgh); auto tau_acc = tau.template get_access(cgh); - auto devInfo_acc = devInfo.template get_access(cgh); onemkl_rocsolver_host_task(cgh, queue, [=](RocsolverScopedContextHandler &sc) { auto handle = sc.get_handle(queue); auto a_ = sc.get_mem(a_acc); auto d_ = sc.get_mem(d_acc); auto e_ = sc.get_mem(e_acc); auto tau_ = sc.get_mem(tau_acc); - auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, d_, e_, tau_); }); }); - lapack_info_check(queue, devInfo, __func__, func_name); } #define HETRD_LAUNCHER(TYPE_A, TYPE_B, ROCSOLVER_ROUTINE) \ @@ -471,8 +471,8 @@ inline void orgbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, - a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + n, k, a_, lda, tau_); }); }); } @@ -504,7 +504,7 @@ inline void orgqr(const char *func_name, Func func, sycl::queue &queue, std::int auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); } @@ -536,8 +536,8 @@ inline void orgtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, tau_); }); }); } @@ -573,9 +573,9 @@ inline void ormtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, - a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), + m, n, a_, lda, tau_, c_, ldc); }); }); } @@ -625,8 +625,9 @@ inline void ormqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, + ldc); }); }); } @@ -660,8 +661,8 @@ inline void potrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -696,8 +697,8 @@ inline void potri(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -732,8 +733,8 @@ inline void potrs(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto b_ = sc.get_mem(b_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, - nrhs, a_, lda, b_, ldb); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, nrhs, a_, lda, b_, ldb); }); }); } @@ -772,8 +773,9 @@ inline void syevd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_job(jobz), - get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -814,9 +816,9 @@ inline void sygvd(const char *func_name, Func func, sycl::queue &queue, std::int auto devInfo_ = sc.get_mem(devInfo_acc); auto scratch_ = sc.get_mem(scratch_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_itype(itype), - get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, + lda, b_, ldb, w_, scratch_, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -855,8 +857,8 @@ inline void sytrd(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto e_ = sc.get_mem(e_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, d_, e_, tau_); }); }); } @@ -890,7 +892,7 @@ inline void sytrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: std::uint64_t ipiv_size = n; sycl::buffer ipiv32(sycl::range<1>{ ipiv_size }); - queue.submit([&](sycl::handler &cgh) { + auto done = queue.submit([&](sycl::handler &cgh) { auto a_acc = a.template get_access(cgh); auto ipiv32_acc = ipiv32.template get_access(cgh); auto devInfo_acc = devInfo.template get_access(cgh); @@ -900,13 +902,14 @@ inline void sytrf(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto ipiv32_ = sc.get_mem(ipiv32_acc); auto devInfo_ = sc.get_mem(devInfo_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, ipiv32_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, ipiv32_, devInfo_); }); }); // Copy from 32-bit buffer to 64-bit queue.submit([&](sycl::handler &cgh) { + cgh.depends_on(done); auto ipiv32_acc = ipiv32.template get_access(cgh); auto ipiv_acc = ipiv.template get_access(cgh); cgh.parallel_for(sycl::range<1>{ ipiv_size }, [=](sycl::id<1> index) { @@ -973,8 +976,8 @@ inline void ungbr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, - a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + n, k, a_, lda, tau_); }); }); } @@ -1006,7 +1009,7 @@ inline void ungqr(const char *func_name, Func func, sycl::queue &queue, std::int auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); } @@ -1038,8 +1041,8 @@ inline void ungtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto a_ = sc.get_mem(a_acc); auto tau_ = sc.get_mem(tau_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, tau_); }); }); } @@ -1089,8 +1092,9 @@ inline void unmqr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, + ldc); }); }); } @@ -1127,9 +1131,9 @@ inline void unmtr(const char *func_name, Func func, sycl::queue &queue, oneapi:: auto tau_ = sc.get_mem(tau_acc); auto c_ = sc.get_mem(c_acc); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, - a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), + m, n, a_, lda, tau_, c_, ldc); }); }); } @@ -1173,8 +1177,8 @@ inline sycl::event gebrd(const char *func_name, Func func, sycl::queue &queue, s auto tauq_ = reinterpret_cast(tauq); auto taup_ = reinterpret_cast(taup); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, - taup_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, d_, e_, tauq_, + taup_); }); }); return done; @@ -1234,7 +1238,7 @@ inline sycl::event geqrf(const char *func_name, Func func, sycl::queue &queue, s auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, tau_); }); }); return done; @@ -1281,7 +1285,8 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto ipiv_ = reinterpret_cast(ipiv32); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, a_, lda, ipiv_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, a_, lda, ipiv_, + devInfo_); }); }); @@ -1293,11 +1298,8 @@ inline sycl::event getrf(const char *func_name, Func func, sycl::queue &queue, s }); }); - queue.wait(); - - free(ipiv32, queue); - lapack_info_check(queue, devInfo, __func__, func_name); + free(ipiv32, queue); free(devInfo, queue); return done_casting; } @@ -1372,8 +1374,8 @@ inline sycl::event getrs(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto b_ = reinterpret_cast(b); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_operation(trans), n, - nrhs, a_, lda, ipiv_, b_, ldb); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_operation(trans), + n, nrhs, a_, lda, ipiv_, b_, ldb); }); }); @@ -1424,9 +1426,10 @@ inline sycl::event gesvd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), - get_rocsolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, vt_, - ldvt, scratch_, rocblas_workmode::rocblas_outofplace, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_jobsvd(jobu), + get_rocsolver_jobsvd(jobvt), m, n, a_, lda, s_, u_, ldu, + vt_, ldvt, scratch_, rocblas_workmode::rocblas_outofplace, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1472,8 +1475,9 @@ inline sycl::event heevd(const char *func_name, Func func, sycl::queue &queue, auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_job(jobz), - get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1518,9 +1522,9 @@ inline sycl::event hegvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_itype(itype), - get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, devInfo); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, + lda, b_, ldb, w_, scratch_, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1551,7 +1555,6 @@ inline sycl::event hetrd(const char *func_name, Func func, sycl::queue &queue, using rocmDataType_A = typename RocmEquivalentType::Type; using rocmDataType_B = typename RocmEquivalentType::Type; overflow_check(n, lda, scratchpad_size); - int *devInfo = (int *)malloc_device(sizeof(int), queue); auto done = queue.submit([&](sycl::handler &cgh) { int64_t num_events = dependencies.size(); for (int64_t i = 0; i < num_events; i++) { @@ -1563,14 +1566,11 @@ inline sycl::event hetrd(const char *func_name, Func func, sycl::queue &queue, auto d_ = reinterpret_cast(d); auto e_ = reinterpret_cast(e); auto tau_ = reinterpret_cast(tau); - auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, d_, e_, tau_); }); }); - lapack_info_check(queue, devInfo, __func__, func_name); - free(devInfo, queue); return done; } @@ -1619,8 +1619,8 @@ inline sycl::event orgbr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, - a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + n, k, a_, lda, tau_); }); }); return done; @@ -1657,7 +1657,7 @@ inline sycl::event orgqr(const char *func_name, Func func, sycl::queue &queue, s auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); return done; @@ -1693,8 +1693,8 @@ inline sycl::event orgtr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, tau_); }); }); return done; @@ -1733,9 +1733,9 @@ inline sycl::event ormtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, - a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), + m, n, a_, lda, tau_, c_, ldc); }); }); return done; @@ -1788,8 +1788,9 @@ inline sycl::event ormqr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, + ldc); }); }); return done; @@ -1828,8 +1829,8 @@ inline sycl::event potrf(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1871,8 +1872,8 @@ inline sycl::event potri(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -1913,8 +1914,8 @@ inline sycl::event potrs(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto b_ = reinterpret_cast(b); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, - nrhs, a_, lda, b_, ldb); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, nrhs, a_, lda, b_, ldb); }); }); return done; @@ -1956,8 +1957,9 @@ inline sycl::event syevd(const char *func_name, Func func, sycl::queue &queue, auto scratch_ = reinterpret_cast(scratchpad); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_job(jobz), - get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_job(jobz), + get_rocblas_fill_mode(uplo), n, a_, lda, w_, scratch_, + devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -2001,9 +2003,9 @@ inline sycl::event sygvd(const char *func_name, Func func, sycl::queue &queue, s auto devInfo_ = reinterpret_cast(devInfo); auto scratch_ = reinterpret_cast(scratchpad); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocsolver_itype(itype), - get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, lda, - b_, ldb, w_, scratch_, devInfo); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocsolver_itype(itype), + get_rocsolver_job(jobz), get_rocblas_fill_mode(uplo), n, a_, + lda, b_, ldb, w_, scratch_, devInfo_); }); }); lapack_info_check(queue, devInfo, __func__, func_name); @@ -2044,8 +2046,8 @@ inline sycl::event sytrd(const char *func_name, Func func, sycl::queue &queue, auto e_ = reinterpret_cast(e); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, d_, e_, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, d_, e_, tau_); }); }); return done; @@ -2091,8 +2093,8 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue, auto ipiv_ = reinterpret_cast(ipiv32); auto devInfo_ = reinterpret_cast(devInfo); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, ipiv_, devInfo_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, ipiv_, devInfo_); }); }); @@ -2104,11 +2106,8 @@ inline sycl::event sytrf(const char *func_name, Func func, sycl::queue &queue, }); }); - queue.wait(); - - free(ipiv32, queue); - lapack_info_check(queue, devInfo, __func__, func_name); + free(ipiv32, queue); free(devInfo, queue); return done_casting; } @@ -2174,8 +2173,8 @@ inline sycl::event ungbr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_generate(vec), m, n, k, - a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_generate(vec), m, + n, k, a_, lda, tau_); }); }); return done; @@ -2212,7 +2211,7 @@ inline sycl::event ungqr(const char *func_name, Func func, sycl::queue &queue, s auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, m, n, k, a_, lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, m, n, k, a_, lda, tau_); }); }); return done; @@ -2248,8 +2247,8 @@ inline sycl::event ungtr(const char *func_name, Func func, sycl::queue &queue, auto a_ = reinterpret_cast(a); auto tau_ = reinterpret_cast(tau); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_fill_mode(uplo), n, a_, - lda, tau_); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_fill_mode(uplo), + n, a_, lda, tau_); }); }); return done; @@ -2302,8 +2301,9 @@ inline sycl::event unmqr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_operation(trans), m, n, k, a_, lda, tau_, c_, + ldc); }); }); return done; @@ -2344,9 +2344,9 @@ inline sycl::event unmtr(const char *func_name, Func func, sycl::queue &queue, auto tau_ = reinterpret_cast(tau); auto c_ = reinterpret_cast(c); rocblas_status err; - ROCSOLVER_ERROR_FUNC_T(func_name, func, err, handle, get_rocblas_side_mode(side), - get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), m, n, - a_, lda, tau_, c_, ldc); + ROCSOLVER_ERROR_FUNC_T_SYNC(func_name, func, err, handle, get_rocblas_side_mode(side), + get_rocblas_fill_mode(uplo), get_rocblas_operation(trans), + m, n, a_, lda, tau_, c_, ldc); }); }); return done; diff --git a/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp b/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp index 78935f979..42e262e7b 100644 --- a/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp +++ b/src/lapack/backends/rocsolver/rocsolver_scope_handle.cpp @@ -45,10 +45,11 @@ RocsolverScopedContextHandler::RocsolverScopedContextHandler(sycl::queue queue, : ih(ih), needToRecover_(false) { placedContext_ = new sycl::context(queue.get_context()); - auto device = queue.get_device(); - auto desired = sycl::get_native(*placedContext_); + auto hipDevice = ih.get_native_device(); hipError_t err; + hipCtx_t desired; HIP_ERROR_FUNC(hipCtxGetCurrent, err, &original_); + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, err, &desired, hipDevice); if (original_ != desired) { // Sets the desired context as the active one for the thread HIP_ERROR_FUNC(hipCtxSetCurrent, err, desired); @@ -90,8 +91,11 @@ void ContextCallback(void *userData) { } rocblas_handle RocsolverScopedContextHandler::get_handle(const sycl::queue &queue) { - auto piPlacedContext_ = reinterpret_cast( - sycl::get_native(*placedContext_)); + auto hipDevice = ih.get_native_device(); + hipError_t hipErr; + hipCtx_t desired; + HIP_ERROR_FUNC(hipDevicePrimaryCtxRetain, hipErr, &desired, hipDevice); + auto piPlacedContext_ = reinterpret_cast(desired); hipStream_t streamId = get_stream(queue); rocblas_status err; auto it = handle_helper.rocsolver_handle_mapper_.find(piPlacedContext_); diff --git a/src/lapack/backends/rocsolver/rocsolver_scope_handle.hpp b/src/lapack/backends/rocsolver/rocsolver_scope_handle.hpp index e87abae5c..9f1bc068a 100644 --- a/src/lapack/backends/rocsolver/rocsolver_scope_handle.hpp +++ b/src/lapack/backends/rocsolver/rocsolver_scope_handle.hpp @@ -57,7 +57,7 @@ class RocsolverScopedContextHandler { // will be fixed when SYCL-2020 has been implemented for Pi backend. template inline T get_mem(U acc) { - hipDeviceptr_t hipPtr = ih.get_native_mem(acc); + hipDeviceptr_t hipPtr = ih.get_native_mem(acc); return reinterpret_cast(hipPtr); } }; diff --git a/src/lapack/backends/rocsolver/rocsolver_task.hpp b/src/lapack/backends/rocsolver/rocsolver_task.hpp index fae2ba893..08f8e5cea 100644 --- a/src/lapack/backends/rocsolver/rocsolver_task.hpp +++ b/src/lapack/backends/rocsolver/rocsolver_task.hpp @@ -22,8 +22,8 @@ #ifndef _MKL_LAPACK_ROCSOLVER_TASK_HPP_ #define _MKL_LAPACK_ROCSOLVER_TASK_HPP_ #include -#include -#include +#include +#include #include #if __has_include() #include diff --git a/src/rng/CMakeLists.txt b/src/rng/CMakeLists.txt index 48e8909a6..30df39403 100644 --- a/src/rng/CMakeLists.txt +++ b/src/rng/CMakeLists.txt @@ -29,6 +29,7 @@ target_include_directories(onemkl_rng ${PROJECT_SOURCE_DIR}/src ${PROJECT_SOURCE_DIR}/src/include ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} $ ) diff --git a/src/rng/backends/CMakeLists.txt b/src/rng/backends/CMakeLists.txt index 7324a6103..9045f7e75 100644 --- a/src/rng/backends/CMakeLists.txt +++ b/src/rng/backends/CMakeLists.txt @@ -17,6 +17,9 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== +add_custom_target(onemkl_backend_libs_rng) +add_dependencies(onemkl_backend_libs onemkl_backend_libs_rng) + if(ENABLE_MKLCPU_BACKEND) add_subdirectory(mklcpu) endif() diff --git a/src/rng/backends/curand/CMakeLists.txt b/src/rng/backends/curand/CMakeLists.txt index 41c2da185..f37a34f1d 100644 --- a/src/rng/backends/curand/CMakeLists.txt +++ b/src/rng/backends/curand/CMakeLists.txt @@ -66,12 +66,14 @@ set(SOURCES philox4x32x10.cpp add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_rng ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ONEMKL::cuRAND::cuRAND) diff --git a/src/rng/backends/curand/curand_helper.hpp b/src/rng/backends/curand/curand_helper.hpp index f3feff3a8..3926e6283 100644 --- a/src/rng/backends/curand/curand_helper.hpp +++ b/src/rng/backends/curand/curand_helper.hpp @@ -211,22 +211,22 @@ static inline void range_transform_fp(sycl::queue& queue, T a, T b, std::int64_t sycl::buffer& r) { queue.submit([&](sycl::handler& cgh) { auto acc = r.template get_access(cgh); - cgh.parallel_for(sycl::range<1>(n), - [=](sycl::id<1> id) { acc[id] = acc[id] * (b - a) + a; }); + cgh.parallel_for(n, [=](sycl::id<1> id) { acc[id] = acc[id] * (b - a) + a; }); }); } + template -static inline sycl::event range_transform_fp(sycl::queue& queue, T a, T b, std::int64_t n, T* r) { - return queue.submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> id) { r[id] = r[id] * (b - a) + a; }); - }); +static inline sycl::event range_transform_fp(sycl::queue& queue, T a, T b, std::int64_t n, T* r, + sycl::event dependency) { + return queue.parallel_for(n, dependency, [=](sycl::id<1> id) { r[id] = r[id] * (b - a) + a; }); } + template static inline void range_transform_fp_accurate(sycl::queue& queue, T a, T b, std::int64_t n, sycl::buffer& r) { queue.submit([&](sycl::handler& cgh) { auto acc = r.template get_access(cgh); - cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> id) { + cgh.parallel_for(n, [=](sycl::id<1> id) { acc[id] = acc[id] * (b - a) + a; if (acc[id] < a) { acc[id] = a; @@ -237,19 +237,18 @@ static inline void range_transform_fp_accurate(sycl::queue& queue, T a, T b, std }); }); } + template static inline sycl::event range_transform_fp_accurate(sycl::queue& queue, T a, T b, std::int64_t n, - T* r) { - return queue.submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> id) { - r[id] = r[id] * (b - a) + a; - if (r[id] < a) { - r[id] = a; - } - else if (r[id] > b) { - r[id] = b; - } - }); + T* r, sycl::event dependency) { + return queue.parallel_for(n, dependency, [=](sycl::id<1> id) { + r[id] = r[id] * (b - a) + a; + if (r[id] < a) { + r[id] = a; + } + else if (r[id] > b) { + r[id] = b; + } }); } @@ -275,17 +274,15 @@ inline void range_transform_int(sycl::queue& queue, T a, T b, std::int64_t n, queue.submit([&](sycl::handler& cgh) { auto acc_in = in.template get_access(cgh); auto acc_out = out.template get_access(cgh); - cgh.parallel_for(sycl::range<1>(n), - [=](sycl::id<1> id) { acc_out[id] = a + acc_in[id] % (b - a); }); + cgh.parallel_for(n, [=](sycl::id<1> id) { acc_out[id] = a + acc_in[id] % (b - a); }); }); } + template inline sycl::event range_transform_int(sycl::queue& queue, T a, T b, std::int64_t n, - std::uint32_t* in, T* out) { - return queue.submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl::range<1>(n), - [=](sycl::id<1> id) { out[id] = a + in[id] % (b - a); }); - }); + std::uint32_t* in, T* out, sycl::event dependency) { + return queue.parallel_for(n, dependency, + [=](sycl::id<1> id) { out[id] = a + in[id] % (b - a); }); } // Static template functions oneapi::mkl::rng::curand::sample_bernoulli for @@ -311,15 +308,14 @@ static inline void sample_bernoulli_from_uniform(sycl::queue& queue, float p, st queue.submit([&](sycl::handler& cgh) { auto acc_in = in.template get_access(cgh); auto acc_out = out.template get_access(cgh); - cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> id) { acc_out[id] = acc_in[id] < p; }); + cgh.parallel_for(n, [=](sycl::id<1> id) { acc_out[id] = acc_in[id] < p; }); }); } + template static inline sycl::event sample_bernoulli_from_uniform(sycl::queue& queue, float p, std::int64_t n, float* in, T* out) { - return queue.submit([&](sycl::handler& cgh) { - cgh.parallel_for(sycl::range<1>(n), [=](sycl::id<1> id) { out[id] = in[id] < p; }); - }); + return queue.parallel_for(n, [=](sycl::id<1> id) { out[id] = in[id] < p; }); } } // namespace curand diff --git a/src/rng/backends/curand/mrg32k3a.cpp b/src/rng/backends/curand/mrg32k3a.cpp index 4f0622fc5..dd44f4def 100644 --- a/src/rng/backends/curand/mrg32k3a.cpp +++ b/src/rng/backends/curand/mrg32k3a.cpp @@ -62,8 +62,10 @@ #include #endif #ifndef __HIPSYCL__ -#if __has_include() +#if __has_include() +#if __SYCL_COMPILER_VERSION <= 20220930 #include +#endif #else #include #endif @@ -108,107 +110,93 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_fp(queue_, distr.a(), distr.b(), n, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); + }); + }); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_fp(queue_, distr.a(), distr.b(), n, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); + }); + }); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } virtual void generate(const oneapi::mkl::rng::uniform< std::int32_t, oneapi::mkl::rng::uniform_method::standard>& distr, std::int64_t n, sycl::buffer& r) override { sycl::buffer ib(n); - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = ib.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = ib.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); + }); + }); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); + }); + }); range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); + }); + }); range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } virtual void generate(const oneapi::mkl::rng::gaussian< float, oneapi::mkl::rng::gaussian_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateNormal, status, engine_, r_ptr, n, distr.mean(), - distr.stddev()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateNormal, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }); } virtual void generate(const oneapi::mkl::rng::gaussian< double, oneapi::mkl::rng::gaussian_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateNormalDouble, status, engine_, r_ptr, n, distr.mean(), - distr.stddev()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateNormalDouble, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }); } virtual void generate( @@ -230,31 +218,27 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void generate(const oneapi::mkl::rng::lognormal< float, oneapi::mkl::rng::lognormal_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateLogNormal, status, engine_, r_ptr, n, distr.m(), - distr.s()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateLogNormal, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }); } virtual void generate(const oneapi::mkl::rng::lognormal< double, oneapi::mkl::rng::lognormal_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateLogNormalDouble, status, engine_, r_ptr, n, distr.m(), - distr.s()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateLogNormalDouble, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }); } virtual void generate( @@ -303,15 +287,13 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void generate(const bits& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.template get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.template get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); + }); + }); } // USM APIs @@ -320,77 +302,76 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { const oneapi::mkl::rng::uniform& distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r, n); + }); + }); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); + }); + }); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, std::int32_t* r, const std::vector& dependencies) override { - std::uint32_t* ib = (std::uint32_t*)malloc_device( - n * sizeof(std::uint32_t), queue_.get_device(), queue_.get_context()); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerate, status, engine_, ib, n); - }); - }) + auto usm_deleter = [this](std::uint32_t* ptr) { + sycl::free(ptr, this->queue_); + }; + std::unique_ptr usm_ib( + sycl::malloc_device(n, queue_), usm_deleter); + std::uint32_t* ib = usm_ib.get(); + sycl::event::wait_and_throw(dependencies); + + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerate, status, engine_, ib, n); + }); + }); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r, generate_event) .wait_and_throw(); - return range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); + return sycl::event{}; } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r, n); + }); + }); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r, + generate_event); } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); + }); + }); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( diff --git a/src/rng/backends/curand/philox4x32x10.cpp b/src/rng/backends/curand/philox4x32x10.cpp index 7f3981067..c3d4393d2 100644 --- a/src/rng/backends/curand/philox4x32x10.cpp +++ b/src/rng/backends/curand/philox4x32x10.cpp @@ -62,8 +62,10 @@ #include #endif #ifndef __HIPSYCL__ -#if __has_include() +#if __has_include() +#if __SYCL_COMPILER_VERSION <= 20220930 #include +#endif #else #include #endif @@ -131,107 +133,93 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { virtual inline void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_fp(queue_, distr.a(), distr.b(), n, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); + }); + }); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_fp(queue_, distr.a(), distr.b(), n, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); + }); + }); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } virtual void generate(const oneapi::mkl::rng::uniform< std::int32_t, oneapi::mkl::rng::uniform_method::standard>& distr, std::int64_t n, sycl::buffer& r) override { sycl::buffer ib(n); - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = ib.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = ib.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); + }); + }); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r_ptr, n); + }); + }); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } virtual void generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); - range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r_ptr, n); + }); + }); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } virtual void generate(const oneapi::mkl::rng::gaussian< float, oneapi::mkl::rng::gaussian_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateNormal, status, engine_, r_ptr, n, distr.mean(), - distr.stddev()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateNormal, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }); } virtual void generate(const oneapi::mkl::rng::gaussian< double, oneapi::mkl::rng::gaussian_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateNormalDouble, status, engine_, r_ptr, n, distr.mean(), - distr.stddev()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateNormalDouble, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }); } virtual void generate( @@ -253,31 +241,27 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void generate(const oneapi::mkl::rng::lognormal< float, oneapi::mkl::rng::lognormal_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateLogNormal, status, engine_, r_ptr, n, distr.m(), - distr.s()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateLogNormal, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }); } virtual void generate(const oneapi::mkl::rng::lognormal< double, oneapi::mkl::rng::lognormal_method::box_muller2>& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerateLogNormalDouble, status, engine_, r_ptr, n, distr.m(), - distr.s()); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerateLogNormalDouble, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }); } virtual void generate( @@ -326,15 +310,13 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void generate(const bits& distr, std::int64_t n, sycl::buffer& r) override { - queue_ - .submit([&](sycl::handler& cgh) { - auto acc = r.template get_access(cgh); - onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { - curandStatus_t status; - CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); - }); - }) - .wait_and_throw(); + queue_.submit([&](sycl::handler& cgh) { + auto acc = r.template get_access(cgh); + onemkl_curand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + curandStatus_t status; + CURAND_CALL(curandGenerate, status, engine_, r_ptr, n); + }); + }); } // USM APIs @@ -343,77 +325,75 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { const oneapi::mkl::rng::uniform& distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r, n); + }); + }); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); + }); + }); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, std::int32_t* r, const std::vector& dependencies) override { - std::uint32_t* ib = (std::uint32_t*)malloc_device( - n * sizeof(std::uint32_t), queue_.get_device(), queue_.get_context()); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerate, status, engine_, ib, n); - }); - }) + auto usm_deleter = [this](std::uint32_t* ptr) { + sycl::free(ptr, this->queue_); + }; + std::unique_ptr usm_ib( + sycl::malloc_device(n, queue_), usm_deleter); + std::uint32_t* ib = usm_ib.get(); + sycl::event::wait_and_throw(dependencies); + + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerate, status, engine_, ib, n); + }); + }); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r, generate_event) .wait_and_throw(); - return range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); + return sycl::event{}; } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniform, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniform, status, engine_, r, n); + }); + }); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( const oneapi::mkl::rng::uniform& distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - queue_ - .submit([&](sycl::handler& cgh) { - onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { - curandStatus_t status; - CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); - }); - }) - .wait_and_throw(); - return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); + sycl::event generate_event = queue_.submit([&](sycl::handler& cgh) { + onemkl_curand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + curandStatus_t status; + CURAND_CALL(curandGenerateUniformDouble, status, engine_, r, n); + }); + }); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r, generate_event); } virtual sycl::event generate( diff --git a/src/rng/backends/mklcpu/CMakeLists.txt b/src/rng/backends/mklcpu/CMakeLists.txt index 984eeabe3..e72ce048f 100644 --- a/src/rng/backends/mklcpu/CMakeLists.txt +++ b/src/rng/backends/mklcpu/CMakeLists.txt @@ -20,8 +20,6 @@ set(LIB_NAME onemkl_rng_mklcpu) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(MKL REQUIRED) - set(SOURCES cpu_common.hpp philox4x32x10.cpp mrg32k3a.cpp @@ -30,18 +28,19 @@ set(SOURCES cpu_common.hpp add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_rng ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET ${LIB_OBJ} SOURCES ${SOURCES}) endif() -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_C}) +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL) set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/rng/backends/mklcpu/cpu_common.hpp b/src/rng/backends/mklcpu/cpu_common.hpp index de9d52772..cbd6cae59 100644 --- a/src/rng/backends/mklcpu/cpu_common.hpp +++ b/src/rng/backends/mklcpu/cpu_common.hpp @@ -56,6 +56,16 @@ class kernel_name {}; template class kernel_name_usm {}; +template +typename Acc::value_type *get_raw_ptr(Acc acc) { +// Workaround for AdaptiveCPP, as they do not yet support the get_multi_ptr function +#ifndef __HIPSYCL__ + return acc.template get_multi_ptr().get_raw(); +#else + return acc.get_pointer(); +#endif +} + } // namespace mklcpu } // namespace rng } // namespace mkl diff --git a/src/rng/backends/mklcpu/mrg32k3a.cpp b/src/rng/backends/mklcpu/mrg32k3a.cpp index c1b4ef5a4..cc234de45 100644 --- a/src/rng/backends/mklcpu/mrg32k3a.cpp +++ b/src/rng/backends/mklcpu/mrg32k3a.cpp @@ -67,8 +67,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -81,8 +81,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -95,8 +95,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -109,8 +109,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -123,8 +123,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -137,8 +137,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -151,8 +151,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -165,8 +165,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -179,8 +179,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -193,8 +193,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngLognormal(VSL_RNG_METHOD_LOGNORMAL_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -208,8 +208,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngLognormal(VSL_RNG_METHOD_LOGNORMAL_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -223,8 +223,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngLognormal(VSL_RNG_METHOD_LOGNORMAL_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -238,8 +238,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngLognormal(VSL_RNG_METHOD_LOGNORMAL_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -253,8 +253,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.p()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.p()); }); }); } @@ -266,9 +266,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_stream = stream_buf.get_access(cgh); auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { - std::uint32_t* r_ptr = acc_r.get_pointer(); + std::uint32_t* r_ptr = get_raw_ptr(acc_r); viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, - static_cast(acc_stream.get_pointer()), n, + static_cast(get_raw_ptr(acc_stream)), n, reinterpret_cast(r_ptr), distr.p()); }); }); @@ -282,8 +282,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngPoisson(VSL_RNG_METHOD_POISSON_POISNORM, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.lambda()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.lambda()); }); }); } @@ -295,9 +295,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_stream = stream_buf.get_access(cgh); auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { - std::uint32_t* r_ptr = acc_r.get_pointer(); + std::uint32_t* r_ptr = get_raw_ptr(acc_r); viRngPoisson(VSL_RNG_METHOD_POISSON_POISNORM, - static_cast(acc_stream.get_pointer()), n, + static_cast(get_raw_ptr(acc_stream)), n, reinterpret_cast(r_ptr), distr.lambda()); }); }); @@ -311,8 +311,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r)); }); }); } diff --git a/src/rng/backends/mklcpu/philox4x32x10.cpp b/src/rng/backends/mklcpu/philox4x32x10.cpp index c687df3f5..3f8e5e89b 100644 --- a/src/rng/backends/mklcpu/philox4x32x10.cpp +++ b/src/rng/backends/mklcpu/philox4x32x10.cpp @@ -69,8 +69,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -83,8 +83,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -97,8 +97,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngUniform(VSL_RNG_METHOD_UNIFORM_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -111,8 +111,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -125,8 +125,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngUniform(VSL_RNG_METHOD_UNIFORM_STD_ACCURATE, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.a(), distr.b()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.a(), distr.b()); }); }); } @@ -139,8 +139,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -153,8 +153,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -167,8 +167,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -181,8 +181,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngGaussian(VSL_RNG_METHOD_GAUSSIAN_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.mean(), distr.stddev()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.mean(), distr.stddev()); }); }); } @@ -195,8 +195,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngLognormal(VSL_RNG_METHOD_LOGNORMAL_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -210,8 +210,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngLognormal(VSL_RNG_METHOD_LOGNORMAL_BOXMULLER2, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -225,8 +225,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vsRngLognormal(VSL_RNG_METHOD_LOGNORMAL_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -240,8 +240,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { vdRngLognormal(VSL_RNG_METHOD_LOGNORMAL_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.m(), distr.s(), distr.displ(), + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.m(), distr.s(), distr.displ(), distr.scale()); }); }); @@ -255,8 +255,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.p()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.p()); }); }); } @@ -268,9 +268,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_stream = stream_buf.get_access(cgh); auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { - std::uint32_t* r_ptr = acc_r.get_pointer(); + std::uint32_t* r_ptr = get_raw_ptr(acc_r); viRngBernoulli(VSL_RNG_METHOD_BERNOULLI_ICDF, - static_cast(acc_stream.get_pointer()), n, + static_cast(get_raw_ptr(acc_stream)), n, reinterpret_cast(r_ptr), distr.p()); }); }); @@ -284,8 +284,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngPoisson(VSL_RNG_METHOD_POISSON_POISNORM, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer(), distr.lambda()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r), distr.lambda()); }); }); } @@ -297,9 +297,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_stream = stream_buf.get_access(cgh); auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { - std::uint32_t* r_ptr = acc_r.get_pointer(); + std::uint32_t* r_ptr = get_raw_ptr(acc_r); viRngPoisson(VSL_RNG_METHOD_POISSON_POISNORM, - static_cast(acc_stream.get_pointer()), n, + static_cast(get_raw_ptr(acc_stream)), n, reinterpret_cast(r_ptr), distr.lambda()); }); }); @@ -313,8 +313,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { auto acc_r = r.get_access(cgh); host_task>(cgh, [=]() { viRngUniformBits(VSL_RNG_METHOD_UNIFORMBITS_STD, - static_cast(acc_stream.get_pointer()), n, - acc_r.get_pointer()); + static_cast(get_raw_ptr(acc_stream)), n, + get_raw_ptr(acc_r)); }); }); } diff --git a/src/rng/backends/mklgpu/CMakeLists.txt b/src/rng/backends/mklgpu/CMakeLists.txt index 99a08302e..150f90136 100644 --- a/src/rng/backends/mklgpu/CMakeLists.txt +++ b/src/rng/backends/mklgpu/CMakeLists.txt @@ -19,7 +19,6 @@ set(LIB_NAME onemkl_rng_mklgpu) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(MKL REQUIRED) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT @@ -28,17 +27,22 @@ add_library(${LIB_OBJ} OBJECT mrg32k3a.cpp $<$: mkl_rng_gpu_wrappers.cpp> ) +add_dependencies(onemkl_backend_libs_rng ${LIB_NAME}) target_include_directories(${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src ${CMAKE_BINARY_DIR}/bin - ${MKL_INCLUDE} + ${ONEMKL_GENERATED_INCLUDE_PATH} ) -target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT} ${MKL_COPT}) +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL ${MKL_LINK_SYCL}) +if(TARGET MKL::MKL_SYCL::RNG) + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_SYCL::RNG) +else() + target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL MKL::MKL_DPCPP) +endif() set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON diff --git a/src/rng/backends/rocrand/CMakeLists.txt b/src/rng/backends/rocrand/CMakeLists.txt index c3c98fd4b..47929703b 100644 --- a/src/rng/backends/rocrand/CMakeLists.txt +++ b/src/rng/backends/rocrand/CMakeLists.txt @@ -54,20 +54,23 @@ set(LIB_NAME onemkl_rng_rocrand) set(LIB_OBJ ${LIB_NAME}_obj) -find_package(rocRAND REQUIRED) +find_package(hip REQUIRED) +find_package(rocrand REQUIRED) +find_package(Threads REQUIRED) set(SOURCES philox4x32x10.cpp mrg32k3a.cpp $<$: mkl_rng_rocrand_wrappers.cpp>) add_library(${LIB_NAME}) add_library(${LIB_OBJ} OBJECT ${SOURCES}) +add_dependencies(onemkl_backend_libs_rng ${LIB_NAME}) target_include_directories( ${LIB_OBJ} PRIVATE ${PROJECT_SOURCE_DIR}/include ${PROJECT_SOURCE_DIR}/src - ${CMAKE_BINARY_DIR}/bin ${MKL_INCLUDE}) + ${CMAKE_BINARY_DIR}/bin ${MKL_INCLUDE} ${ONEMKL_GENERATED_INCLUDE_PATH}) -target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL - ONEMKL::rocRAND::rocRAND) +target_link_libraries(${LIB_OBJ} PRIVATE roc::rocrand hip::host Threads::Threads) +target_link_libraries(${LIB_OBJ} PUBLIC ONEMKL::SYCL::SYCL) target_compile_features(${LIB_OBJ} PUBLIC cxx_std_11) set_target_properties(${LIB_OBJ} PROPERTIES POSITION_INDEPENDENT_CODE ON) diff --git a/src/rng/backends/rocrand/mrg32k3a.cpp b/src/rng/backends/rocrand/mrg32k3a.cpp index 1709bd6c7..424f14caf 100644 --- a/src/rng/backends/rocrand/mrg32k3a.cpp +++ b/src/rng/backends/rocrand/mrg32k3a.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) * and Computing Centre (URZ) - * cuRAND back-end Copyright (c) 2021, The Regents of the University of + * rocRAND back-end Copyright (c) 2021, The Regents of the University of * California, through Lawrence Berkeley National Laboratory (subject to receipt * of any required approvals from the U.S. Dept. of Energy). All rights * reserved. @@ -88,7 +88,9 @@ namespace rocrand { class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { public: mrg32k3a_impl(sycl::queue queue, std::uint32_t seed) - : oneapi::mkl::rng::detail::engine_impl(queue) { + : oneapi::mkl::rng::detail::engine_impl(queue), + seed_(seed), + offset_(0) { rocrand_status status; ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_MRG32K3A); ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed); @@ -97,12 +99,19 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { mrg32k3a_impl(sycl::queue queue, std::initializer_list seed) : oneapi::mkl::rng::detail::engine_impl(queue) { throw oneapi::mkl::unimplemented("rng", "mrg32ka engine", - "multi-seed unsupported by cuRAND backend"); + "multi-seed unsupported by rocRAND backend"); } - mrg32k3a_impl(const mrg32k3a_impl* other) : oneapi::mkl::rng::detail::engine_impl(*other) { - throw oneapi::mkl::unimplemented("rng", "mrg32ka engine", - "copy construction unsupported by cuRAND backend"); + mrg32k3a_impl(const mrg32k3a_impl* other) + : oneapi::mkl::rng::detail::engine_impl(*other), + seed_(other->seed_), + offset_(other->offset_) { + rocrand_status status; + ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_MRG32K3A); + ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed_); + + // Allign this->engine_'s offset state with other->engine_'s offset + skip_ahead(offset_); } // Buffers API @@ -119,6 +128,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -134,6 +146,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -150,6 +165,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -165,6 +183,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -180,6 +201,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -196,6 +220,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::gaussian< @@ -211,22 +237,42 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r_ptr, n, + distr.mean(), distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -242,6 +288,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -257,50 +305,88 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r_ptr, n, + distr.m(), distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::int32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r_ptr, + n, distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r_ptr, n, + distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bits& distr, std::int64_t n, @@ -314,6 +400,8 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } // USM APIs @@ -330,6 +418,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -345,6 +436,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -362,6 +456,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -377,6 +474,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -392,6 +492,9 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -400,13 +503,17 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -414,31 +521,51 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -446,13 +573,17 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -460,31 +591,51 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bernoulli& distr, @@ -492,7 +643,7 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } @@ -501,37 +652,57 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::int32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r, n, + distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "mrg32ka engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r, n, distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate, status, engine_, r, n); }); }); + + increment_internal_offset(n); + + return event; } virtual oneapi::mkl::rng::detail::engine_impl* copy_state() override { @@ -545,11 +716,11 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void skip_ahead(std::initializer_list num_to_skip) override { throw oneapi::mkl::unimplemented("rng", "skip_ahead", - "initializer list unsupported by cuRAND backend"); + "initializer list unsupported by rocRAND backend"); } virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override { - throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by cuRAND backend"); + throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by rocRAND backend"); } virtual ~mrg32k3a_impl() override { @@ -559,8 +730,13 @@ class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { private: rocrand_generator engine_; std::uint32_t seed_; + std::uint64_t offset_; + + void increment_internal_offset(std::uint64_t n) { + offset_ += n; + } }; -#else // cuRAND backend is currently not supported on Windows +#else // rocRAND backend is currently not supported on Windows class mrg32k3a_impl : public oneapi::mkl::rng::detail::engine_impl { public: mrg32k3a_impl(sycl::queue queue, std::uint32_t seed) diff --git a/src/rng/backends/rocrand/philox4x32x10.cpp b/src/rng/backends/rocrand/philox4x32x10.cpp index 1b3511a1c..5bc241360 100644 --- a/src/rng/backends/rocrand/philox4x32x10.cpp +++ b/src/rng/backends/rocrand/philox4x32x10.cpp @@ -1,7 +1,7 @@ /******************************************************************************* * Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) * and Computing Centre (URZ) - * cuRAND back-end Copyright (c) 2021, The Regents of the University of + * rocRAND back-end Copyright (c) 2021, The Regents of the University of * California, through Lawrence Berkeley National Laboratory (subject to receipt * of any required approvals from the U.S. Dept. of Energy). All rights * reserved. @@ -86,7 +86,7 @@ namespace rocrand { #if !defined(_WIN64) /* - * Note that cuRAND consists of two pieces: a host (CPU) API and a device (GPU) + * Note that rocRAND consists of two pieces: a host (CPU) API and a device (GPU) * API. The host API acts like any standard library; the `rocrand.h' header is * included and the functions can be called as usual. The generator is * instantiated on the host and random numbers can be generated on either the @@ -110,7 +110,9 @@ namespace rocrand { class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { public: philox4x32x10_impl(sycl::queue queue, std::uint64_t seed) - : oneapi::mkl::rng::detail::engine_impl(queue) { + : oneapi::mkl::rng::detail::engine_impl(queue), + seed_(seed), + offset_(0) { rocrand_status status; ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed); @@ -119,13 +121,19 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { philox4x32x10_impl(sycl::queue queue, std::initializer_list seed) : oneapi::mkl::rng::detail::engine_impl(queue) { throw oneapi::mkl::unimplemented("rng", "philox4x32x10 engine", - "multi-seed unsupported by cuRAND backend"); + "multi-seed unsupported by rocRAND backend"); } philox4x32x10_impl(const philox4x32x10_impl* other) - : oneapi::mkl::rng::detail::engine_impl(*other) { - throw oneapi::mkl::unimplemented("rng", "philox4x32x10 engine", - "copy construction unsupported by cuRAND backend"); + : oneapi::mkl::rng::detail::engine_impl(*other), + seed_(other->seed_), + offset_(other->offset_) { + rocrand_status status; + ROCRAND_CALL(rocrand_create_generator, status, &engine_, ROCRAND_RNG_PSEUDO_PHILOX4_32_10); + ROCRAND_CALL(rocrand_set_seed, status, engine_, (unsigned long long)seed_); + + // Allign this->engine_'s offset state with other->engine_'s offset + skip_ahead(offset_); } // Buffers API @@ -142,6 +150,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -157,6 +168,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -173,6 +187,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -188,6 +205,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -203,6 +223,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -219,6 +242,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::gaussian< @@ -234,22 +259,42 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r_ptr, n, distr.mean(), + distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r_ptr, n, + distr.mean(), distr.stddev()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -265,6 +310,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const oneapi::mkl::rng::lognormal< @@ -280,50 +327,88 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](float* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r_ptr, n, distr.m(), + distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](double* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r_ptr, n, + distr.m(), distr.s()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const bernoulli& distr, std::int64_t n, sycl::buffer& r) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::int32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r_ptr, + n, distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const poisson& distr, std::int64_t n, sycl::buffer& r) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + queue_ + .submit([&](sycl::handler& cgh) { + auto acc = r.get_access(cgh); + onemkl_rocrand_host_task(cgh, acc, engine_, [=](std::uint32_t* r_ptr) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r_ptr, n, + distr.lambda()); + }); + }) + .wait_and_throw(); + + increment_internal_offset(n); } virtual void generate(const bits& distr, std::int64_t n, @@ -337,6 +422,8 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); } // USM APIs @@ -353,6 +440,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -368,6 +458,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp(queue_, distr.a(), distr.b(), n, r); } @@ -385,6 +478,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_int(queue_, distr.a(), distr.b(), n, ib, r); } @@ -400,6 +496,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -415,6 +514,9 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { }); }) .wait_and_throw(); + + increment_internal_offset(n); + return range_transform_fp_accurate(queue_, distr.a(), distr.b(), n, r); } @@ -423,13 +525,17 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -437,31 +543,51 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), distr.stddev()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::gaussian& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_normal_double, status, engine_, r, n, distr.mean(), + distr.stddev()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -469,13 +595,17 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, float* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( @@ -483,31 +613,51 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { distr, std::int64_t n, double* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), distr.s()); }); }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, float* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const oneapi::mkl::rng::lognormal& distr, std::int64_t n, double* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_log_normal_double, status, engine_, r, n, distr.m(), + distr.s()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bernoulli& distr, @@ -515,7 +665,7 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } @@ -524,37 +674,56 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { const std::vector& dependencies) override { throw oneapi::mkl::unimplemented( "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); + "Bernoulli distribution method unsupported by rocRAND backend"); return sycl::event{}; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::int32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, (std::uint32_t*)r, n, + distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate( const poisson& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { - throw oneapi::mkl::unimplemented( - "rng", "philox4x32x10 engine", - "ICDF method not used for pseudorandom generators in cuRAND backend"); - return sycl::event{}; + sycl::event::wait_and_throw(dependencies); + auto event = queue_.submit([&](sycl::handler& cgh) { + onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { + rocrand_status status; + ROCRAND_CALL(rocrand_generate_poisson, status, engine_, r, n, distr.lambda()); + }); + }); + + increment_internal_offset(n); + + return event; } virtual sycl::event generate(const bits& distr, std::int64_t n, std::uint32_t* r, const std::vector& dependencies) override { sycl::event::wait_and_throw(dependencies); - return queue_.submit([&](sycl::handler& cgh) { + auto event = queue_.submit([&](sycl::handler& cgh) { onemkl_rocrand_host_task(cgh, engine_, [=](sycl::interop_handle ih) { rocrand_status status; ROCRAND_CALL(rocrand_generate, status, engine_, r, n); }); }); + + increment_internal_offset(n); + + return event; } virtual oneapi::mkl::rng::detail::engine_impl* copy_state() override { @@ -568,11 +737,11 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { virtual void skip_ahead(std::initializer_list num_to_skip) override { throw oneapi::mkl::unimplemented("rng", "skip_ahead", - "initializer list unsupported by cuRAND backend"); + "initializer list unsupported by rocRAND backend"); } virtual void leapfrog(std::uint64_t idx, std::uint64_t stride) override { - throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by cuRAND backend"); + throw oneapi::mkl::unimplemented("rng", "leapfrog", "unsupported by rocRAND backend"); } virtual ~philox4x32x10_impl() override { @@ -581,8 +750,14 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { private: rocrand_generator engine_; + std::uint64_t seed_; + std::uint64_t offset_; + + void increment_internal_offset(std::uint64_t n) { + offset_ += n; + } }; -#else // cuRAND backend is currently not supported on Windows +#else // rocRAND backend is currently not supported on Windows class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { public: philox4x32x10_impl(sycl::queue queue, std::uint64_t seed) @@ -859,8 +1034,7 @@ class philox4x32x10_impl : public oneapi::mkl::rng::detail::engine_impl { #endif oneapi::mkl::rng::detail::engine_impl* create_philox4x32x10(sycl::queue queue, std::uint64_t seed) { - auto a = new philox4x32x10_impl(queue, seed); - return a; + return new philox4x32x10_impl(queue, seed); } oneapi::mkl::rng::detail::engine_impl* create_philox4x32x10( diff --git a/src/rng/backends/rocrand/rocrand_helper.hpp b/src/rng/backends/rocrand/rocrand_helper.hpp index 6be9269a6..205429ee8 100644 --- a/src/rng/backends/rocrand/rocrand_helper.hpp +++ b/src/rng/backends/rocrand/rocrand_helper.hpp @@ -315,10 +315,10 @@ class rocm_error : virtual public std::runtime_error { } }; -#define HIP_ERROR_FUNC(name, err, ...) \ - err = name(__VA_ARGS__); \ - if (err != HIP_SUCCESS) { \ - throw hip_error(std::string(#name) + std::string(" : "), err); \ +#define HIP_ERROR_FUNC(name, err, ...) \ + err = name(__VA_ARGS__); \ + if (err != HIP_SUCCESS) { \ + throw rocm_error(std::string(#name) + std::string(" : "), err); \ } #define ROCRAND_CALL(func, status, ...) \ diff --git a/src/rng/backends/rocrand/rocrand_task.hpp b/src/rng/backends/rocrand/rocrand_task.hpp index 4646ca342..2588dc901 100644 --- a/src/rng/backends/rocrand/rocrand_task.hpp +++ b/src/rng/backends/rocrand/rocrand_task.hpp @@ -43,6 +43,9 @@ static inline void host_task_internal(H &cgh, A acc, E e, F f) { auto r_ptr = reinterpret_cast( ih.get_native_mem(acc)); f(r_ptr); + + hipError_t err; + HIP_ERROR_FUNC(hipStreamSynchronize, err, stream); }); } @@ -53,6 +56,9 @@ static inline void host_task_internal(H &cgh, E e, F f) { auto stream = ih.get_native_queue(); ROCRAND_CALL(rocrand_set_stream, status, e, stream); f(ih); + + hipError_t err; + HIP_ERROR_FUNC(hipStreamSynchronize, err, stream); }); } #endif diff --git a/src/sparse_blas/CMakeLists.txt b/src/sparse_blas/CMakeLists.txt new file mode 100644 index 000000000..b93902f49 --- /dev/null +++ b/src/sparse_blas/CMakeLists.txt @@ -0,0 +1,48 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +add_subdirectory(backends) + +if(BUILD_SHARED_LIBS) + add_library(onemkl_sparse_blas OBJECT) + target_sources(onemkl_sparse_blas PRIVATE sparse_blas_loader.cpp) + target_include_directories(onemkl_sparse_blas + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${PROJECT_SOURCE_DIR}/src/include + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} + $ + ) + + target_compile_options(onemkl_sparse_blas PRIVATE ${ONEMKL_BUILD_COPT}) + + set_target_properties(onemkl_sparse_blas PROPERTIES + POSITION_INDEPENDENT_CODE ON + ) + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET onemkl_sparse_blas SOURCES sparse_blas_loader.cpp) + else() + target_link_libraries(onemkl_sparse_blas PUBLIC ONEMKL::SYCL::SYCL) + endif() + + include(WarningsUtils) + target_link_libraries(onemkl_sparse_blas PRIVATE onemkl_warnings) + +endif() diff --git a/src/sparse_blas/backends/CMakeLists.txt b/src/sparse_blas/backends/CMakeLists.txt new file mode 100644 index 000000000..ef606c6e1 --- /dev/null +++ b/src/sparse_blas/backends/CMakeLists.txt @@ -0,0 +1,29 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +add_custom_target(onemkl_backend_libs_sparse_blas) +add_dependencies(onemkl_backend_libs onemkl_backend_libs_sparse_blas) + +if(ENABLE_MKLCPU_BACKEND) + add_subdirectory(mklcpu) +endif() + +if(ENABLE_MKLGPU_BACKEND) + add_subdirectory(mklgpu) +endif() diff --git a/src/sparse_blas/backends/backend_wrappers.cxx b/src/sparse_blas/backends/backend_wrappers.cxx new file mode 100644 index 000000000..2c8161249 --- /dev/null +++ b/src/sparse_blas/backends/backend_wrappers.cxx @@ -0,0 +1,85 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +This file lists functions matching those required by sparse_blas_function_table_t in +src/sparse_blas/function_table.hpp. + +To use this: + +#define WRAPPER_VERSION +#define BACKEND + +extern "C" sparse_blas_function_table_t mkl_sparse_blas_table = { + WRAPPER_VERSION, +#include "sparse_blas/backends/backend_wrappers.cxx" +}; + +Changes to this file should be matched to changes in sparse_blas/function_table.hpp. The required +function template instantiations must be added to backend_sparse_blas_instantiations.cxx. +*/ + +// clang-format off +oneapi::mkl::sparse::BACKEND::init_matrix_handle, +oneapi::mkl::sparse::BACKEND::release_matrix_handle, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::set_csr_data, +oneapi::mkl::sparse::BACKEND::optimize_gemm, +oneapi::mkl::sparse::BACKEND::optimize_gemm, +oneapi::mkl::sparse::BACKEND::optimize_gemv, +oneapi::mkl::sparse::BACKEND::optimize_trsv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::gemv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::trsv, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, +oneapi::mkl::sparse::BACKEND::gemm, + // clang-format on diff --git a/src/sparse_blas/backends/mkl_common/mkl_basic.cxx b/src/sparse_blas/backends/mkl_common/mkl_basic.cxx new file mode 100644 index 000000000..fd3b1563a --- /dev/null +++ b/src/sparse_blas/backends/mkl_common/mkl_basic.cxx @@ -0,0 +1,62 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +void init_matrix_handle(sycl::queue & /*queue*/, detail::matrix_handle **p_handle) { + oneapi::mkl::sparse::init_matrix_handle(detail::get_handle(p_handle)); +} + +sycl::event release_matrix_handle(sycl::queue &queue, detail::matrix_handle **p_handle, + const std::vector &dependencies) { + return oneapi::mkl::sparse::release_matrix_handle(queue, detail::get_handle(p_handle), + dependencies); +} + +template +std::enable_if_t> set_csr_data( + sycl::queue &queue, detail::matrix_handle *handle, intType num_rows, intType num_cols, + intType /*nnz*/, index_base index, sycl::buffer &row_ptr, + sycl::buffer &col_ind, sycl::buffer &val) { + oneapi::mkl::sparse::set_csr_data(queue, detail::get_handle(handle), num_rows, num_cols, index, + row_ptr, col_ind, val); +} + +template +std::enable_if_t, sycl::event> set_csr_data( + sycl::queue &queue, detail::matrix_handle *handle, intType num_rows, intType num_cols, + intType /*nnz*/, index_base index, intType *row_ptr, intType *col_ind, fpType *val, + const std::vector &dependencies) { + return oneapi::mkl::sparse::set_csr_data(queue, detail::get_handle(handle), num_rows, num_cols, + index, row_ptr, col_ind, val, dependencies); +} + +#define INSTANTIATE_SET_CSR_DATA(FP_TYPE, INT_TYPE) \ + template std::enable_if_t> \ + set_csr_data( \ + sycl::queue & queue, detail::matrix_handle * handle, INT_TYPE num_rows, INT_TYPE num_cols, \ + INT_TYPE nnz, index_base index, sycl::buffer & row_ptr, \ + sycl::buffer & col_ind, sycl::buffer & val); \ + template std::enable_if_t, sycl::event> \ + set_csr_data(sycl::queue & queue, detail::matrix_handle * handle, \ + INT_TYPE num_rows, INT_TYPE num_cols, INT_TYPE nnz, \ + index_base index, INT_TYPE * row_ptr, INT_TYPE * col_ind, \ + FP_TYPE * val, const std::vector &dependencies) + +FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_SET_CSR_DATA); + +#undef INSTANTIATE_SET_CSR_DATA diff --git a/src/sparse_blas/backends/mkl_common/mkl_helper.hpp b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp new file mode 100644 index 000000000..da5235ee0 --- /dev/null +++ b/src/sparse_blas/backends/mkl_common/mkl_helper.hpp @@ -0,0 +1,56 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +// MKLCPU and MKLGPU backends include +// This include defines its own oneapi::mkl::sparse namespace with some of the types that are used here: matrix_handle_t, index_base, transpose, uolo, diag. +#include + +// Includes are set up so that oneapi::mkl::sparse namespace refers to the MKLCPU and MKLGPU backends namespace (oneMKL product) +// in this file. +// oneapi::mkl::sparse::detail namespace refers to the oneMKL interface namespace. + +#include "oneapi/mkl/sparse_blas/detail/helper_types.hpp" + +namespace oneapi::mkl::sparse::detail { + +inline auto get_handle(detail::matrix_handle **handle) { + return reinterpret_cast(handle); +} + +inline auto get_handle(detail::matrix_handle *handle) { + return reinterpret_cast(handle); +} + +} // namespace oneapi::mkl::sparse::detail + +#define FOR_EACH_FP_TYPE(INSTANTIATE_MACRO) \ + INSTANTIATE_MACRO(float); \ + INSTANTIATE_MACRO(double); \ + INSTANTIATE_MACRO(std::complex); \ + INSTANTIATE_MACRO(std::complex) + +#define FOR_EACH_FP_AND_INT_TYPE_HELPER(INSTANTIATE_MACRO, INT_TYPE) \ + INSTANTIATE_MACRO(float, INT_TYPE); \ + INSTANTIATE_MACRO(double, INT_TYPE); \ + INSTANTIATE_MACRO(std::complex, INT_TYPE); \ + INSTANTIATE_MACRO(std::complex, INT_TYPE) + +#define FOR_EACH_FP_AND_INT_TYPE(INSTANTIATE_MACRO) \ + FOR_EACH_FP_AND_INT_TYPE_HELPER(INSTANTIATE_MACRO, std::int32_t); \ + FOR_EACH_FP_AND_INT_TYPE_HELPER(INSTANTIATE_MACRO, std::int64_t) diff --git a/src/sparse_blas/backends/mkl_common/mkl_operations.cxx b/src/sparse_blas/backends/mkl_common/mkl_operations.cxx new file mode 100644 index 000000000..ba6960341 --- /dev/null +++ b/src/sparse_blas/backends/mkl_common/mkl_operations.cxx @@ -0,0 +1,170 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +sycl::event optimize_gemm(sycl::queue& queue, transpose /*transpose_A*/, + detail::matrix_handle* /*handle*/, + const std::vector& dependencies) { + // TODO: Call to optimize_gemm with 2024.1 oneMKL release + // Return an event depending on the dependencies + return queue.submit([=](sycl::handler& cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { /* Empty kernel */ }); + }); +} + +sycl::event optimize_gemm(sycl::queue& queue, transpose /*transpose_A*/, transpose /*transpose_B*/, + layout /*dense_matrix_layout*/, const std::int64_t /*columns*/, + detail::matrix_handle* /*handle*/, + const std::vector& dependencies) { + // TODO: Call to optimize_gemm with 2024.1 oneMKL release + // Return an event depending on the dependencies + return queue.submit([=](sycl::handler& cgh) { + cgh.depends_on(dependencies); + cgh.host_task([=]() { /* Empty kernel */ }); + }); +} + +sycl::event optimize_gemv(sycl::queue& queue, transpose transpose_val, + detail::matrix_handle* handle, + const std::vector& dependencies) { + return oneapi::mkl::sparse::optimize_gemv(queue, transpose_val, detail::get_handle(handle), + dependencies); +} + +sycl::event optimize_trsv(sycl::queue& queue, uplo uplo_val, transpose transpose_val, diag diag_val, + detail::matrix_handle* handle, + const std::vector& dependencies) { + // TODO: Remove this if condition once Intel oneMKL adds support for trans/conjtrans to optimize_trsv + if (transpose_val != transpose::nontrans) { + throw mkl::unimplemented("sparse_blas/backends/mkl", __FUNCTION__, + "Transposed or conjugate trsv is not supported"); + } + return oneapi::mkl::sparse::optimize_trsv(queue, uplo_val, transpose_val, diag_val, + detail::get_handle(handle), dependencies); +} + +template +std::enable_if_t> gemv( + sycl::queue& queue, transpose transpose_val, const fpType alpha, + detail::matrix_handle* A_handle, sycl::buffer& x, const fpType beta, + sycl::buffer& y) { + oneapi::mkl::sparse::gemv(queue, transpose_val, alpha, detail::get_handle(A_handle), x, beta, y); +} + +template +std::enable_if_t, sycl::event> gemv( + sycl::queue& queue, transpose transpose_val, const fpType alpha, + detail::matrix_handle* A_handle, const fpType* x, const fpType beta, fpType* y, + const std::vector& dependencies) { + return oneapi::mkl::sparse::gemv(queue, transpose_val, alpha, detail::get_handle(A_handle), x, beta, y, + dependencies); +} + +template +std::enable_if_t> trsv(sycl::queue& queue, uplo uplo_val, + transpose transpose_val, diag diag_val, + detail::matrix_handle* A_handle, + sycl::buffer& x, + sycl::buffer& y) { + // TODO: Remove this if condition once Intel oneMKL adds support for trans/conjtrans to trsv + if (transpose_val != transpose::nontrans) { + throw mkl::unimplemented("sparse_blas/backends/mkl", __FUNCTION__, + "Transposed or conjugate trsv is not supported"); + } + oneapi::mkl::sparse::trsv(queue, uplo_val, transpose_val, diag_val, + detail::get_handle(A_handle), x, y); +} + +template +std::enable_if_t, sycl::event> trsv( + sycl::queue& queue, uplo uplo_val, transpose transpose_val, diag diag_val, + detail::matrix_handle* A_handle, const fpType* x, fpType* y, + const std::vector& dependencies) { + // TODO: Remove this if condition once Intel oneMKL adds support for trans/conjtrans to trsv + if (transpose_val != transpose::nontrans) { + throw mkl::unimplemented("sparse_blas/backends/mkl", __FUNCTION__, + "Transposed or conjugate trsv is not supported"); + } + // TODO: Remove const_cast in future oneMKL release + return oneapi::mkl::sparse::trsv(queue, uplo_val, transpose_val, diag_val, + detail::get_handle(A_handle), const_cast(x), y, + dependencies); +} + +template +std::enable_if_t> gemm( + sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, + const fpType alpha, detail::matrix_handle* A_handle, sycl::buffer& B, + const std::int64_t columns, const std::int64_t ldb, const fpType beta, + sycl::buffer& C, const std::int64_t ldc) { + oneapi::mkl::sparse::gemm(queue, dense_matrix_layout, transpose_A, transpose_B, alpha, + detail::get_handle(A_handle), B, columns, ldb, beta, C, ldc); +} + +template +std::enable_if_t, sycl::event> gemm( + sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, transpose transpose_B, + const fpType alpha, detail::matrix_handle* A_handle, const fpType* B, + const std::int64_t columns, const std::int64_t ldb, const fpType beta, fpType* C, + const std::int64_t ldc, const std::vector& dependencies) { + // TODO: Remove const_cast in future oneMKL release + return oneapi::mkl::sparse::gemm(queue, dense_matrix_layout, transpose_A, transpose_B, alpha, + detail::get_handle(A_handle), const_cast(B), columns, + ldb, beta, C, ldc, dependencies); +} + +#define INSTANTIATE_GEMV(FP_TYPE) \ + template std::enable_if_t> gemv( \ + sycl::queue& queue, transpose transpose_val, const FP_TYPE alpha, \ + detail::matrix_handle* A_handle, sycl::buffer& x, const FP_TYPE beta, \ + sycl::buffer& y); \ + template std::enable_if_t, sycl::event> gemv( \ + sycl::queue& queue, transpose transpose_val, const FP_TYPE alpha, \ + detail::matrix_handle* A_handle, const FP_TYPE* x, const FP_TYPE beta, FP_TYPE* y, \ + const std::vector& dependencies) + +#define INSTANTIATE_TRSV(FP_TYPE) \ + template std::enable_if_t> trsv( \ + sycl::queue& queue, uplo uplo_val, transpose transpose_val, diag diag_val, \ + detail::matrix_handle* A_handle, sycl::buffer& x, \ + sycl::buffer& y); \ + template std::enable_if_t, sycl::event> trsv( \ + sycl::queue& queue, uplo uplo_val, transpose transpose_val, diag diag_val, \ + detail::matrix_handle* A_handle, const FP_TYPE* x, FP_TYPE* y, \ + const std::vector& dependencies) + +#define INSTANTIATE_GEMM(FP_TYPE) \ + template std::enable_if_t> gemm( \ + sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, \ + transpose transpose_B, const FP_TYPE alpha, detail::matrix_handle* A_handle, \ + sycl::buffer& B, const std::int64_t columns, const std::int64_t ldb, \ + const FP_TYPE beta, sycl::buffer& C, const std::int64_t ldc); \ + template std::enable_if_t, sycl::event> gemm( \ + sycl::queue& queue, layout dense_matrix_layout, transpose transpose_A, \ + transpose transpose_B, const FP_TYPE alpha, detail::matrix_handle* A_handle, \ + const FP_TYPE* B, const std::int64_t columns, const std::int64_t ldb, const FP_TYPE beta, \ + FP_TYPE* C, const std::int64_t ldc, const std::vector& dependencies) + +FOR_EACH_FP_TYPE(INSTANTIATE_GEMV); +FOR_EACH_FP_TYPE(INSTANTIATE_TRSV); +FOR_EACH_FP_TYPE(INSTANTIATE_GEMM); + +#undef INSTANTIATE_GEMV +#undef INSTANTIATE_TRSV +#undef INSTANTIATE_GEMM diff --git a/src/sparse_blas/backends/mklcpu/CMakeLists.txt b/src/sparse_blas/backends/mklcpu/CMakeLists.txt new file mode 100644 index 000000000..cfcf9cf3d --- /dev/null +++ b/src/sparse_blas/backends/mklcpu/CMakeLists.txt @@ -0,0 +1,82 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_sparse_blas_mklcpu) +set(LIB_OBJ ${LIB_NAME}_obj) + +include(WarningsUtils) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + mklcpu_basic.cpp + mklcpu_operations.cpp + $<$: mklcpu_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_sparse_blas ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +if(TARGET MKL::MKL_SYCL::SPARSE) + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_SYCL::SPARSE + PRIVATE onemkl_warnings + ) +else() + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_DPCPP + PRIVATE onemkl_warnings + ) +endif() + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/mklcpu/compute_signature.cpp b/src/sparse_blas/backends/mklcpu/mklcpu_basic.cpp similarity index 72% rename from src/dft/backends/mklcpu/compute_signature.cpp rename to src/sparse_blas/backends/mklcpu/mklcpu_basic.cpp index 2efe2d413..9ab29ee92 100644 --- a/src/dft/backends/mklcpu/compute_signature.cpp +++ b/src/sparse_blas/backends/mklcpu/mklcpu_basic.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright Codeplay Software Ltd +* Copyright 2023 Codeplay Software Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,12 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#include "oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp" +#include "../mkl_common/mkl_helper.hpp" -#define BACKEND mklcpu +#include "oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp" -#include "dft/backends/backend_compute_signature.cxx" +namespace oneapi::mkl::sparse::mklcpu { + +#include "../mkl_common/mkl_basic.cxx" + +} // namespace oneapi::mkl::sparse::mklcpu diff --git a/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp b/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp new file mode 100644 index 000000000..e636b1816 --- /dev/null +++ b/src/sparse_blas/backends/mklcpu/mklcpu_operations.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "../mkl_common/mkl_helper.hpp" + +#include "oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp" + +namespace oneapi::mkl::sparse::mklcpu { + +#include "../mkl_common/mkl_operations.cxx" + +} // namespace oneapi::mkl::sparse::mklcpu diff --git a/src/sparse_blas/backends/mklcpu/mklcpu_wrappers.cpp b/src/sparse_blas/backends/mklcpu/mklcpu_wrappers.cpp new file mode 100644 index 000000000..40f75c60c --- /dev/null +++ b/src/sparse_blas/backends/mklcpu/mklcpu_wrappers.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/sparse_blas/types.hpp" + +#include "oneapi/mkl/sparse_blas/detail/mklcpu/onemkl_sparse_blas_mklcpu.hpp" + +#include "sparse_blas/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND mklcpu + +extern "C" sparse_blas_function_table_t mkl_sparse_blas_table = { + WRAPPER_VERSION, +#include "sparse_blas/backends/backend_wrappers.cxx" +}; diff --git a/src/sparse_blas/backends/mklgpu/CMakeLists.txt b/src/sparse_blas/backends/mklgpu/CMakeLists.txt new file mode 100644 index 000000000..a31794547 --- /dev/null +++ b/src/sparse_blas/backends/mklgpu/CMakeLists.txt @@ -0,0 +1,82 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(LIB_NAME onemkl_sparse_blas_mklgpu) +set(LIB_OBJ ${LIB_NAME}_obj) + +include(WarningsUtils) + +add_library(${LIB_NAME}) +add_library(${LIB_OBJ} OBJECT + mklgpu_basic.cpp + mklgpu_operations.cpp + $<$: mklgpu_wrappers.cpp> +) +add_dependencies(onemkl_backend_libs_sparse_blas ${LIB_NAME}) + +target_include_directories(${LIB_OBJ} + PRIVATE ${PROJECT_SOURCE_DIR}/include + ${PROJECT_SOURCE_DIR}/src + ${CMAKE_BINARY_DIR}/bin + ${ONEMKL_GENERATED_INCLUDE_PATH} +) + +target_compile_options(${LIB_OBJ} PRIVATE ${ONEMKL_BUILD_COPT}) + +if(TARGET MKL::MKL_SYCL::SPARSE) + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_SYCL::SPARSE + PRIVATE onemkl_warnings + ) +else() + target_link_libraries(${LIB_OBJ} + PUBLIC ONEMKL::SYCL::SYCL + PUBLIC MKL::MKL_DPCPP + PRIVATE onemkl_warnings + ) +endif() + +set_target_properties(${LIB_OBJ} PROPERTIES + POSITION_INDEPENDENT_CODE ON +) +target_link_libraries(${LIB_NAME} PUBLIC ${LIB_OBJ}) + +#Set oneMKL libraries as not transitive for dynamic +if(BUILD_SHARED_LIBS) + set_target_properties(${LIB_NAME} PROPERTIES + INTERFACE_LINK_LIBRARIES ONEMKL::SYCL::SYCL + ) +endif() + +# Add major version to the library +set_target_properties(${LIB_NAME} PROPERTIES + SOVERSION ${PROJECT_VERSION_MAJOR} +) + +# Add dependencies rpath to the library +list(APPEND CMAKE_BUILD_RPATH $) + +# Add the library to install package +install(TARGETS ${LIB_OBJ} EXPORT oneMKLTargets) +install(TARGETS ${LIB_NAME} EXPORT oneMKLTargets + RUNTIME DESTINATION bin + ARCHIVE DESTINATION lib + LIBRARY DESTINATION lib +) diff --git a/src/dft/backends/mklgpu/compute_signature.cpp b/src/sparse_blas/backends/mklgpu/mklgpu_basic.cpp similarity index 72% rename from src/dft/backends/mklgpu/compute_signature.cpp rename to src/sparse_blas/backends/mklgpu/mklgpu_basic.cpp index 9027b012b..8df24f8da 100644 --- a/src/dft/backends/mklgpu/compute_signature.cpp +++ b/src/sparse_blas/backends/mklgpu/mklgpu_basic.cpp @@ -1,5 +1,5 @@ /******************************************************************************* -* Copyright Codeplay Software Ltd +* Copyright 2023 Codeplay Software Ltd. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,8 +17,12 @@ * SPDX-License-Identifier: Apache-2.0 *******************************************************************************/ -#include "oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp" +#include "../mkl_common/mkl_helper.hpp" -#define BACKEND mklgpu +#include "oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp" -#include "dft/backends/backend_compute_signature.cxx" +namespace oneapi::mkl::sparse::mklgpu { + +#include "../mkl_common/mkl_basic.cxx" + +} // namespace oneapi::mkl::sparse::mklgpu diff --git a/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp b/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp new file mode 100644 index 000000000..439dc4eea --- /dev/null +++ b/src/sparse_blas/backends/mklgpu/mklgpu_operations.cpp @@ -0,0 +1,28 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "../mkl_common/mkl_helper.hpp" + +#include "oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp" + +namespace oneapi::mkl::sparse::mklgpu { + +#include "../mkl_common/mkl_operations.cxx" + +} // namespace oneapi::mkl::sparse::mklgpu diff --git a/src/sparse_blas/backends/mklgpu/mklgpu_wrappers.cpp b/src/sparse_blas/backends/mklgpu/mklgpu_wrappers.cpp new file mode 100644 index 000000000..346b13540 --- /dev/null +++ b/src/sparse_blas/backends/mklgpu/mklgpu_wrappers.cpp @@ -0,0 +1,32 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/sparse_blas/types.hpp" + +#include "oneapi/mkl/sparse_blas/detail/mklgpu/onemkl_sparse_blas_mklgpu.hpp" + +#include "sparse_blas/function_table.hpp" + +#define WRAPPER_VERSION 1 +#define BACKEND mklgpu + +extern "C" sparse_blas_function_table_t mkl_sparse_blas_table = { + WRAPPER_VERSION, +#include "sparse_blas/backends/backend_wrappers.cxx" +}; diff --git a/src/sparse_blas/function_table.hpp b/src/sparse_blas/function_table.hpp new file mode 100644 index 000000000..57279fb3f --- /dev/null +++ b/src/sparse_blas/function_table.hpp @@ -0,0 +1,109 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* (*Licensed under the Apache License, Version 2.0 )(the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_FUNCTION_TABLE_HPP_ +#define _ONEMKL_SPARSE_BLAS_FUNCTION_TABLE_HPP_ + +#include "oneapi/mkl/sparse_blas/types.hpp" +#include "sparse_blas/macros.hpp" + +#define DEFINE_SET_CSR_DATA(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ + void (*set_csr_data_buffer##FP_SUFFIX##INT_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t handle, INT_TYPE num_rows, \ + INT_TYPE num_cols, INT_TYPE nnz, oneapi::mkl::index_base index, \ + sycl::buffer & row_ptr, sycl::buffer & col_ind, \ + sycl::buffer & val); \ + sycl::event (*set_csr_data_usm##FP_SUFFIX##INT_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::sparse::matrix_handle_t handle, INT_TYPE num_rows, \ + INT_TYPE num_cols, INT_TYPE nnz, oneapi::mkl::index_base index, INT_TYPE * row_ptr, \ + INT_TYPE * col_ind, FP_TYPE * val, const std::vector &dependencies) + +#define DEFINE_GEMV(FP_TYPE, FP_SUFFIX) \ + void (*gemv_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::transpose transpose_val, const FP_TYPE alpha, \ + oneapi::mkl::sparse::matrix_handle_t A_handle, sycl::buffer &x, \ + const FP_TYPE beta, sycl::buffer &y); \ + sycl::event (*gemv_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::transpose transpose_val, const FP_TYPE alpha, \ + oneapi::mkl::sparse::matrix_handle_t A_handle, const FP_TYPE *x, const FP_TYPE beta, \ + FP_TYPE *y, const std::vector &dependencies) + +#define DEFINE_TRSV(FP_TYPE, FP_SUFFIX) \ + void (*trsv_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, \ + oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, \ + sycl::buffer & x, sycl::buffer & y); \ + sycl::event (*trsv_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, \ + oneapi::mkl::diag diag_val, oneapi::mkl::sparse::matrix_handle_t A_handle, \ + const FP_TYPE *x, FP_TYPE *y, const std::vector &dependencies) + +#define DEFINE_GEMM(FP_TYPE, FP_SUFFIX) \ + void (*gemm_buffer##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::layout dense_matrix_layout, \ + oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, \ + const FP_TYPE alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, \ + sycl::buffer &B, const std::int64_t columns, const std::int64_t ldb, \ + const FP_TYPE beta, sycl::buffer &C, const std::int64_t ldc); \ + sycl::event (*gemm_usm##FP_SUFFIX)( \ + sycl::queue & queue, oneapi::mkl::layout dense_matrix_layout, \ + oneapi::mkl::transpose transpose_A, oneapi::mkl::transpose transpose_B, \ + const FP_TYPE alpha, oneapi::mkl::sparse::matrix_handle_t A_handle, const FP_TYPE *B, \ + const std::int64_t columns, const std::int64_t ldb, const FP_TYPE beta, FP_TYPE *C, \ + const std::int64_t ldc, const std::vector &dependencies) + +typedef struct { + int version; + void (*init_matrix_handle)(sycl::queue &queue, oneapi::mkl::sparse::matrix_handle_t *p_handle); + + sycl::event (*release_matrix_handle)(sycl::queue &queue, + oneapi::mkl::sparse::matrix_handle_t *p_handle, + const std::vector &dependencies); + + FOR_EACH_FP_AND_INT_TYPE(DEFINE_SET_CSR_DATA); + + // optimize_* + sycl::event (*optimize_gemm_v1)(sycl::queue &queue, oneapi::mkl::transpose transpose_A, + oneapi::mkl::sparse::matrix_handle_t handle, + const std::vector &dependencies); + sycl::event (*optimize_gemm_v2)(sycl::queue &queue, oneapi::mkl::transpose transpose_A, + oneapi::mkl::transpose transpose_B, + oneapi::mkl::layout dense_matrix_layout, + const std::int64_t columns, + oneapi::mkl::sparse::matrix_handle_t handle, + const std::vector &dependencies); + sycl::event (*optimize_gemv)(sycl::queue &queue, oneapi::mkl::transpose transpose_val, + oneapi::mkl::sparse::matrix_handle_t handle, + const std::vector &dependencies); + sycl::event (*optimize_trsv)(sycl::queue &queue, oneapi::mkl::uplo uplo_val, + oneapi::mkl::transpose transpose_val, oneapi::mkl::diag diag_val, + oneapi::mkl::sparse::matrix_handle_t handle, + const std::vector &dependencies); + + FOR_EACH_FP_TYPE(DEFINE_GEMV); + FOR_EACH_FP_TYPE(DEFINE_TRSV); + FOR_EACH_FP_TYPE(DEFINE_GEMM); +} sparse_blas_function_table_t; + +#undef DEFINE_SET_CSR_DATA +#undef DEFINE_GEMV +#undef DEFINE_TRSV +#undef DEFINE_GEMM + +#endif // _ONEMKL_SPARSE_BLAS_FUNCTION_TABLE_HPP_ diff --git a/src/sparse_blas/macros.hpp b/src/sparse_blas/macros.hpp new file mode 100644 index 000000000..a4ef88e35 --- /dev/null +++ b/src/sparse_blas/macros.hpp @@ -0,0 +1,39 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* (*Licensed under the Apache License, Version 2.0 )(the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _ONEMKL_SPARSE_BLAS_MACROS_HPP_ +#define _ONEMKL_SPARSE_BLAS_MACROS_HPP_ + +#define FOR_EACH_FP_TYPE(DEFINE_MACRO) \ + DEFINE_MACRO(float, _rf); \ + DEFINE_MACRO(double, _rd); \ + DEFINE_MACRO(std::complex, _cf); \ + DEFINE_MACRO(std::complex, _cd) + +#define FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, INT_TYPE, INT_SUFFIX) \ + DEFINE_MACRO(float, _rf, INT_TYPE, INT_SUFFIX); \ + DEFINE_MACRO(double, _rd, INT_TYPE, INT_SUFFIX); \ + DEFINE_MACRO(std::complex, _cf, INT_TYPE, INT_SUFFIX); \ + DEFINE_MACRO(std::complex, _cd, INT_TYPE, INT_SUFFIX) + +#define FOR_EACH_FP_AND_INT_TYPE(DEFINE_MACRO) \ + FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, std::int32_t, _i32); \ + FOR_EACH_FP_AND_INT_TYPE_HELPER(DEFINE_MACRO, std::int64_t, _i64) + +#endif // _ONEMKL_SPARSE_BLAS_MACROS_HPP_ diff --git a/src/sparse_blas/sparse_blas_loader.cpp b/src/sparse_blas/sparse_blas_loader.cpp new file mode 100644 index 000000000..95da6df9c --- /dev/null +++ b/src/sparse_blas/sparse_blas_loader.cpp @@ -0,0 +1,162 @@ +/******************************************************************************* +* Copyright 2023 Codeplay Software Ltd. +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "oneapi/mkl/sparse_blas/detail/sparse_blas_rt.hpp" + +#include "function_table_initializer.hpp" +#include "sparse_blas/function_table.hpp" +#include "sparse_blas/macros.hpp" +#include "oneapi/mkl/detail/get_device_id.hpp" + +namespace oneapi::mkl::sparse { + +static oneapi::mkl::detail::table_initializer + function_tables; + +void init_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle) { + auto libkey = get_device_id(queue); + function_tables[libkey].init_matrix_handle(queue, p_handle); +} + +sycl::event release_matrix_handle(sycl::queue &queue, matrix_handle_t *p_handle, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].release_matrix_handle(queue, p_handle, dependencies); +} + +#define DEFINE_SET_CSR_DATA(FP_TYPE, FP_SUFFIX, INT_TYPE, INT_SUFFIX) \ + template <> \ + void set_csr_data(sycl::queue &queue, matrix_handle_t handle, INT_TYPE num_rows, \ + INT_TYPE num_cols, INT_TYPE nnz, index_base index, \ + sycl::buffer &row_ptr, sycl::buffer &col_ind, \ + sycl::buffer &val) { \ + auto libkey = get_device_id(queue); \ + function_tables[libkey].set_csr_data_buffer##FP_SUFFIX##INT_SUFFIX( \ + queue, handle, num_rows, num_cols, nnz, index, row_ptr, col_ind, val); \ + } \ + template <> \ + sycl::event set_csr_data(sycl::queue &queue, matrix_handle_t handle, INT_TYPE num_rows, \ + INT_TYPE num_cols, INT_TYPE nnz, index_base index, INT_TYPE *row_ptr, \ + INT_TYPE *col_ind, FP_TYPE *val, \ + const std::vector &dependencies) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].set_csr_data_usm##FP_SUFFIX##INT_SUFFIX( \ + queue, handle, num_rows, num_cols, nnz, index, row_ptr, col_ind, val, dependencies); \ + } + +FOR_EACH_FP_AND_INT_TYPE(DEFINE_SET_CSR_DATA) +#undef DEFINE_SET_CSR_DATA + +sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, matrix_handle_t handle, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].optimize_gemm_v1(queue, transpose_A, handle, dependencies); +} + +sycl::event optimize_gemm(sycl::queue &queue, transpose transpose_A, transpose transpose_B, + layout dense_matrix_layout, const std::int64_t columns, + matrix_handle_t handle, const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].optimize_gemm_v2( + queue, transpose_A, transpose_B, dense_matrix_layout, columns, handle, dependencies); +} + +sycl::event optimize_gemv(sycl::queue &queue, transpose transpose_val, matrix_handle_t handle, + const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].optimize_gemv(queue, transpose_val, handle, dependencies); +} + +sycl::event optimize_trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, + matrix_handle_t handle, const std::vector &dependencies) { + auto libkey = get_device_id(queue); + return function_tables[libkey].optimize_trsv(queue, uplo_val, transpose_val, diag_val, handle, + dependencies); +} + +#define DEFINE_GEMV(FP_TYPE, FP_SUFFIX) \ + template <> \ + void gemv(sycl::queue &queue, transpose transpose_val, const FP_TYPE alpha, \ + matrix_handle_t A_handle, sycl::buffer &x, const FP_TYPE beta, \ + sycl::buffer &y) { \ + auto libkey = get_device_id(queue); \ + function_tables[libkey].gemv_buffer##FP_SUFFIX(queue, transpose_val, alpha, A_handle, x, \ + beta, y); \ + } \ + template <> \ + sycl::event gemv(sycl::queue &queue, transpose transpose_val, const FP_TYPE alpha, \ + matrix_handle_t A_handle, const FP_TYPE *x, const FP_TYPE beta, FP_TYPE *y, \ + const std::vector &dependencies) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].gemv_usm##FP_SUFFIX(queue, transpose_val, alpha, A_handle, \ + x, beta, y, dependencies); \ + } + +FOR_EACH_FP_TYPE(DEFINE_GEMV) +#undef DEFINE_GEMV + +#define DEFINE_TRSV(FP_TYPE, FP_SUFFIX) \ + template <> \ + void trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, \ + matrix_handle_t A_handle, sycl::buffer &x, \ + sycl::buffer &y) { \ + auto libkey = get_device_id(queue); \ + function_tables[libkey].trsv_buffer##FP_SUFFIX(queue, uplo_val, transpose_val, diag_val, \ + A_handle, x, y); \ + } \ + template <> \ + sycl::event trsv(sycl::queue &queue, uplo uplo_val, transpose transpose_val, diag diag_val, \ + matrix_handle_t A_handle, const FP_TYPE *x, FP_TYPE *y, \ + const std::vector &dependencies) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].trsv_usm##FP_SUFFIX( \ + queue, uplo_val, transpose_val, diag_val, A_handle, x, y, dependencies); \ + } + +FOR_EACH_FP_TYPE(DEFINE_TRSV) +#undef DEFINE_TRSV + +#define DEFINE_GEMM(FP_TYPE, FP_SUFFIX) \ + template <> \ + void gemm(sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, \ + transpose transpose_B, const FP_TYPE alpha, matrix_handle_t A_handle, \ + sycl::buffer &B, const std::int64_t columns, const std::int64_t ldb, \ + const FP_TYPE beta, sycl::buffer &C, const std::int64_t ldc) { \ + auto libkey = get_device_id(queue); \ + function_tables[libkey].gemm_buffer##FP_SUFFIX(queue, dense_matrix_layout, transpose_A, \ + transpose_B, alpha, A_handle, B, columns, \ + ldb, beta, C, ldc); \ + } \ + template <> \ + sycl::event gemm(sycl::queue &queue, layout dense_matrix_layout, transpose transpose_A, \ + transpose transpose_B, const FP_TYPE alpha, matrix_handle_t A_handle, \ + const FP_TYPE *B, const std::int64_t columns, const std::int64_t ldb, \ + const FP_TYPE beta, FP_TYPE *C, const std::int64_t ldc, \ + const std::vector &dependencies) { \ + auto libkey = get_device_id(queue); \ + return function_tables[libkey].gemm_usm##FP_SUFFIX( \ + queue, dense_matrix_layout, transpose_A, transpose_B, alpha, A_handle, B, columns, \ + ldb, beta, C, ldc, dependencies); \ + } + +FOR_EACH_FP_TYPE(DEFINE_GEMM) +#undef DEFINE_GEMM + +} // namespace oneapi::mkl::sparse diff --git a/tests/README.md b/tests/README.md index 54ff0f3c5..3a8346057 100644 --- a/tests/README.md +++ b/tests/README.md @@ -3,39 +3,9 @@ ## Overview Inside the `unit_tests` directory, there are domain-level directories which contain domain-specific tests, usually per function or per configuration. -## Steps -Functional testing is enabled by default via `Conan` and `CMake`, so all relevant functional tests will run automatically after the project is built successfully. +See [Building and Running Tests](https://oneapi-src.github.io/oneMKL/building_and_running_tests.html) documentation for more information about how to build and run the tests. -*Note: A set of `build options` define a `build configuration`. `CMake` builds and runs different set of tests depending on your `build configuration`. This is because `CMake` generates an export header file (config.hpp) for the selected build configuration. Check `/src/config.hpp.in` and `/src/CMakeLists.txt` for details. For details on how `CMake` performs export header generation, refer to [CMake documentation](https://cmake.org/cmake/help/v3.13/module/GenerateExportHeader.html).* - -You can re-run tests without re-building the entire project. - -#### The `CMake` Approach Works for any Generator -```bash -cmake --build . --target test -``` - -#### To use Generator-specific Commands: - -```bash -# For ninja -ninja test -``` - -```bash -# For GNU Makefiles -ctest -# Test filter use case - runs only Gpu specific tests -ctest -R Gpu -# Exclude filtering use case - excludes Cpu tests -ctest -E Cpu -``` - -For more `ctest` options, refer to [ctest manual page](https://cmake.org/cmake/help/v3.13/manual/ctest.1.html). - -## BLAS - -The tests in the level\ directories are for the corresponding level\ BLAS routines. [GoogleTest](https://github.com/google/googletest) is used as the unit-testing framework. +[GoogleTest](https://github.com/google/googletest) is used as the unit-testing framework. *Refer to `/deps/googletest/LICENSE` for GoogleTest license.* diff --git a/tests/unit_tests/CMakeLists.txt b/tests/unit_tests/CMakeLists.txt index 23c1f8582..e7fe8e110 100644 --- a/tests/unit_tests/CMakeLists.txt +++ b/tests/unit_tests/CMakeLists.txt @@ -56,6 +56,10 @@ set(lapack_TEST_LINK ${LAPACKE_LINK}) set(rng_TEST_LIST rng_statistics rng_service) +set(rng_DEVICE_TEST_LIST + rng_device_moments + rng_device_service +) set(rng_TEST_LINK "") @@ -65,12 +69,20 @@ set(dft_TEST_LIST set(dft_TEST_LINK "") +# Sparse BLAS config +set(sparse_blas_TEST_LIST + spblas_source) + +set(sparse_blas_TEST_LINK "") + foreach(domain ${TARGET_DOMAINS}) # Generate RT and CT test lists set(${domain}_TEST_LIST_RT ${${domain}_TEST_LIST}) set(${domain}_TEST_LIST_CT ${${domain}_TEST_LIST}) + set(${domain}_DEVICE_TEST_LIST_CT ${${domain}_DEVICE_TEST_LIST}) list(TRANSFORM ${domain}_TEST_LIST_RT APPEND _rt) list(TRANSFORM ${domain}_TEST_LIST_CT APPEND _ct) + list(TRANSFORM ${domain}_DEVICE_TEST_LIST_CT APPEND _ct) add_executable(test_main_${domain}_ct main_test.cpp) target_include_directories(test_main_${domain}_ct PUBLIC ${GTEST_INCLUDE_DIR}) @@ -84,7 +96,9 @@ foreach(domain ${TARGET_DOMAINS}) if(BUILD_SHARED_LIBS) add_executable(test_main_${domain}_rt main_test.cpp) target_include_directories(test_main_${domain}_rt PUBLIC ${GTEST_INCLUDE_DIR}) - target_compile_options(test_main_${domain}_rt PRIVATE -fsycl) + if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_compile_options(test_main_${domain}_rt PRIVATE -fsycl) + endif() target_link_libraries(test_main_${domain}_rt PUBLIC gtest gtest_main @@ -124,6 +138,11 @@ foreach(domain ${TARGET_DOMAINS}) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_netlib) endif() + if(domain STREQUAL "blas" AND ENABLE_PORTBLAS_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_${domain}_portblas) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_portblas) + endif() + if(domain STREQUAL "lapack" AND ENABLE_CUSOLVER_BACKEND) add_dependencies(test_main_${domain}_ct onemkl_${domain}_cusolver) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_cusolver) @@ -144,6 +163,21 @@ foreach(domain ${TARGET_DOMAINS}) list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_rocrand) endif() + if(domain STREQUAL "dft" AND ENABLE_CUFFT_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_${domain}_cufft) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_${domain}_cufft) + endif() + + if(domain STREQUAL "dft" AND ENABLE_ROCFFT_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_dft_rocfft) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_dft_rocfft) + endif() + + if(domain STREQUAL "dft" AND ENABLE_PORTFFT_BACKEND) + add_dependencies(test_main_${domain}_ct onemkl_dft_portfft) + list(APPEND ONEMKL_LIBRARIES_${domain} onemkl_dft_portfft) + endif() + target_link_libraries(test_main_${domain}_ct PUBLIC gtest gtest_main @@ -152,7 +186,12 @@ foreach(domain ${TARGET_DOMAINS}) ${ONEMKL_LIBRARIES_${domain}} ONEMKL::SYCL::SYCL ${${domain}_TEST_LIST_CT} + ${${domain}_DEVICE_TEST_LIST_CT} ) + + if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_link_options(test_main_${domain}_ct PUBLIC -fsycl-device-code-split=per_kernel) + endif() string(TOUPPER ${domain} DOMAIN_PREFIX) diff --git a/tests/unit_tests/blas/batch/axpy_batch_stride.cpp b/tests/unit_tests/blas/batch/axpy_batch_stride.cpp index 8bd3c489d..9bb1406ef 100644 --- a/tests/unit_tests/blas/batch/axpy_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/axpy_batch_stride.cpp @@ -105,7 +105,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::axpy_batch(main_queue, n, alpha, x_buffer, incx, stride_x, y_buffer, incy, stride_y, batch_size); @@ -119,14 +119,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, fp } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy_batch, n, - alpha, x_buffer, incx, stride_x, y_buffer, incy, stride_y, - batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy_batch, n, + alpha, x_buffer, incx, stride_x, y_buffer, incy, stride_y, + batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy_batch, n, alpha, - x_buffer, incx, stride_x, y_buffer, incy, stride_y, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy_batch, n, + alpha, x_buffer, incx, stride_x, y_buffer, incy, stride_y, + batch_size); break; default: break; } @@ -149,7 +150,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, fp // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = true; for (i = 0; i < batch_size; i++) { good = good && check_equal_vector(y_accessor.get_pointer() + i * stride_y, @@ -172,6 +173,8 @@ TEST_P(AxpyBatchStrideTests, RealSinglePrecision) { } TEST_P(AxpyBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha = 2.0; EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, alpha, 15)); @@ -192,6 +195,8 @@ TEST_P(AxpyBatchStrideTests, ComplexSinglePrecision) { } TEST_P(AxpyBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha = std::complex(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, alpha, 15)); @@ -203,7 +208,7 @@ TEST_P(AxpyBatchStrideTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AxpyBatchStrideTestSuite, AxpyBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/axpy_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/axpy_batch_stride_usm.cpp index 243fbf23e..9ebc82abe 100644 --- a/tests/unit_tests/blas/batch/axpy_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/axpy_batch_stride_usm.cpp @@ -110,7 +110,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::axpy_batch( main_queue, n, alpha, &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, dependencies); @@ -125,15 +125,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, fp done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy_batch, n, - alpha, &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy_batch, n, + alpha, &x[0], incx, stride_x, &y[0], incy, stride_y, + batch_size, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy_batch, n, alpha, - &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy_batch, n, + alpha, &x[0], incx, stride_x, &y[0], incy, stride_y, + batch_size, dependencies); break; default: break; } @@ -179,6 +179,8 @@ TEST_P(AxpyBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(AxpyBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha = 2.0; EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, alpha, 15)); @@ -199,6 +201,8 @@ TEST_P(AxpyBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(AxpyBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha = std::complex(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, alpha, 15)); @@ -210,7 +214,7 @@ TEST_P(AxpyBatchStrideUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AxpyBatchStrideUsmTestSuite, AxpyBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/axpy_batch_usm.cpp b/tests/unit_tests/blas/batch/axpy_batch_usm.cpp index a160cf0f6..4dacf8ddb 100644 --- a/tests/unit_tests/blas/batch/axpy_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/axpy_batch_usm.cpp @@ -157,7 +157,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::axpy_batch( main_queue, n, alpha, (const fp **)x_array, incx, y_array, incy, group_count, group_size, dependencies); @@ -172,15 +172,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy_batch, n, - alpha, (const fp **)x_array, incx, y_array, incy, group_count, - group_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy_batch, n, + alpha, (const fp **)x_array, incx, y_array, incy, + group_count, group_size, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy_batch, n, alpha, - (const fp **)x_array, incx, y_array, incy, group_count, - group_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy_batch, n, + alpha, (const fp **)x_array, incx, y_array, incy, + group_count, group_size, dependencies); break; default: break; } @@ -259,6 +259,8 @@ TEST_P(AxpyBatchUsmTests, RealSinglePrecision) { } TEST_P(AxpyBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -268,13 +270,15 @@ TEST_P(AxpyBatchUsmTests, ComplexSinglePrecision) { } TEST_P(AxpyBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(AxpyBatchUsmTestSuite, AxpyBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/copy_batch_stride.cpp b/tests/unit_tests/blas/batch/copy_batch_stride.cpp index 8905baadd..a1da595f6 100644 --- a/tests/unit_tests/blas/batch/copy_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/copy_batch_stride.cpp @@ -104,7 +104,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::copy_batch(main_queue, n, x_buffer, incx, stride_x, y_buffer, incy, stride_y, batch_size); break; @@ -116,13 +116,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy_batch, n, - x_buffer, incx, stride_x, y_buffer, incy, stride_y, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy_batch, n, + x_buffer, incx, stride_x, y_buffer, incy, stride_y, + batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy_batch, n, - x_buffer, incx, stride_x, y_buffer, incy, stride_y, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy_batch, n, + x_buffer, incx, stride_x, y_buffer, incy, stride_y, + batch_size); break; default: break; } @@ -145,7 +147,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = true; for (i = 0; i < batch_size; i++) { good = good && check_equal_vector(y_accessor.get_pointer() + i * stride_y, @@ -164,6 +166,8 @@ TEST_P(CopyBatchStrideTests, RealSinglePrecision) { } TEST_P(CopyBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 15)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), -2, -3, 15)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1, 1, 15)); @@ -179,6 +183,8 @@ TEST_P(CopyBatchStrideTests, ComplexSinglePrecision) { } TEST_P(CopyBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 15)); EXPECT_TRUEORSKIP( @@ -189,7 +195,7 @@ TEST_P(CopyBatchStrideTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(CopyBatchStrideTestSuite, CopyBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/copy_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/copy_batch_stride_usm.cpp index aea0035a4..569293be1 100644 --- a/tests/unit_tests/blas/batch/copy_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/copy_batch_stride_usm.cpp @@ -109,7 +109,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::copy_batch(main_queue, n, &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, dependencies); @@ -124,14 +124,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy_batch, n, - &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy_batch, n, + &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy_batch, n, &x[0], - incx, stride_x, &y[0], incy, stride_y, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy_batch, n, + &x[0], incx, stride_x, &y[0], incy, stride_y, batch_size, + dependencies); break; default: break; } @@ -174,6 +175,8 @@ TEST_P(CopyBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(CopyBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha = 2.0; EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 15)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), -2, -3, 15)); @@ -191,6 +194,8 @@ TEST_P(CopyBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(CopyBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha = std::complex(2.0, -0.5); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 15)); @@ -202,7 +207,7 @@ TEST_P(CopyBatchStrideUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(CopyBatchStrideUsmTestSuite, CopyBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/copy_batch_usm.cpp b/tests/unit_tests/blas/batch/copy_batch_usm.cpp index 66bdf7f76..8cac23704 100644 --- a/tests/unit_tests/blas/batch/copy_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/copy_batch_usm.cpp @@ -153,7 +153,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::copy_batch( main_queue, n, (const fp **)x_array, incx, y_array, incy, group_count, group_size, dependencies); @@ -168,15 +168,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy_batch, n, - (const fp **)x_array, incx, y_array, incy, group_count, - group_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy_batch, n, + (const fp **)x_array, incx, y_array, incy, group_count, + group_size, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy_batch, n, - (const fp **)x_array, incx, y_array, incy, group_count, - group_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy_batch, n, + (const fp **)x_array, incx, y_array, incy, group_count, + group_size, dependencies); break; default: break; } @@ -253,6 +253,8 @@ TEST_P(CopyBatchUsmTests, RealSinglePrecision) { } TEST_P(CopyBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -262,13 +264,15 @@ TEST_P(CopyBatchUsmTests, ComplexSinglePrecision) { } TEST_P(CopyBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(CopyBatchUsmTestSuite, CopyBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/dgmm_batch_stride.cpp b/tests/unit_tests/blas/batch/dgmm_batch_stride.cpp index 0dca62764..bb642c3ee 100644 --- a/tests/unit_tests/blas/batch/dgmm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/dgmm_batch_stride.cpp @@ -121,7 +121,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::dgmm_batch(main_queue, left_right, m, n, A_buffer, lda, stride_a, x_buffer, incx, stride_x, C_buffer, ldc, stride_c, batch_size); @@ -135,15 +135,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dgmm_batch, - left_right, m, n, A_buffer, lda, stride_a, x_buffer, incx, - stride_x, C_buffer, ldc, stride_c, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dgmm_batch, + left_right, m, n, A_buffer, lda, stride_a, x_buffer, incx, + stride_x, C_buffer, ldc, stride_c, batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dgmm_batch, left_right, - m, n, A_buffer, lda, stride_a, x_buffer, incx, stride_x, - C_buffer, ldc, stride_c, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dgmm_batch, + left_right, m, n, A_buffer, lda, stride_a, x_buffer, incx, + stride_x, C_buffer, ldc, stride_c, batch_size); break; default: break; } @@ -166,7 +166,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = true; for (i = 0; i < batch_size; i++) { good = good && @@ -195,6 +195,8 @@ TEST_P(DgmmBatchStrideTests, RealSinglePrecision) { } TEST_P(DgmmBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::right, 2, 5)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -225,6 +227,8 @@ TEST_P(DgmmBatchStrideTests, ComplexSinglePrecision) { } TEST_P(DgmmBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::right, 2, 5)); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -241,7 +245,7 @@ TEST_P(DgmmBatchStrideTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(DgmmBatchStrideTestSuite, DgmmBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/dgmm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/dgmm_batch_stride_usm.cpp index e671f47f0..bb9cf0df3 100644 --- a/tests/unit_tests/blas/batch/dgmm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/dgmm_batch_stride_usm.cpp @@ -126,7 +126,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::dgmm_batch( main_queue, left_right, m, n, &A[0], lda, stride_a, &x[0], incx, stride_x, &C[0], ldc, stride_c, batch_size, dependencies); @@ -141,15 +141,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dgmm_batch, - left_right, m, n, &A[0], lda, stride_a, &x[0], incx, stride_x, - &C[0], ldc, stride_c, batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dgmm_batch, + left_right, m, n, &A[0], lda, stride_a, &x[0], incx, + stride_x, &C[0], ldc, stride_c, batch_size, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dgmm_batch, left_right, - m, n, &A[0], lda, stride_a, &x[0], incx, stride_x, &C[0], ldc, - stride_c, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dgmm_batch, + left_right, m, n, &A[0], lda, stride_a, &x[0], incx, + stride_x, &C[0], ldc, stride_c, batch_size, dependencies); break; default: break; } @@ -200,6 +200,8 @@ TEST_P(DgmmBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(DgmmBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::right, 2, 5)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -230,6 +232,8 @@ TEST_P(DgmmBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(DgmmBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::right, 2, 5)); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -246,7 +250,7 @@ TEST_P(DgmmBatchStrideUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(DgmmBatchStrideUsmTestSuite, DgmmBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/dgmm_batch_usm.cpp b/tests/unit_tests/blas/batch/dgmm_batch_usm.cpp index d92dbb3e9..1f568580f 100644 --- a/tests/unit_tests/blas/batch/dgmm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/dgmm_batch_usm.cpp @@ -112,10 +112,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { - size_a = (layout == oneapi::mkl::layout::column_major) ? lda[i] * n[i] : lda[i] * m[i]; + size_a = (layout == oneapi::mkl::layout::col_major) ? lda[i] * n[i] : lda[i] * m[i]; x_len = (left_right[i] == oneapi::mkl::side::R) ? n[i] : m[i]; size_x = 1 + (x_len - 1) * std::abs(incx[i]); - size_c = (layout == oneapi::mkl::layout::column_major) ? ldc[i] * n[i] : ldc[i] * m[i]; + size_c = (layout == oneapi::mkl::layout::col_major) ? ldc[i] * n[i] : ldc[i] * m[i]; for (j = 0; j < group_size[i]; j++) { a_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_a, *dev, cxt); x_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_x, *dev, cxt); @@ -187,7 +187,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::dgmm_batch( main_queue, &left_right[0], &m[0], &n[0], (const fp **)&a_array[0], &lda[0], (const fp **)&x_array[0], &incx[0], &c_array[0], &ldc[0], group_count, @@ -204,17 +204,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dgmm_batch, - &left_right[0], &m[0], &n[0], (const fp **)&a_array[0], &lda[0], - (const fp **)&x_array[0], &incx[0], &c_array[0], &ldc[0], - group_count, &group_size[0], dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dgmm_batch, + &left_right[0], &m[0], &n[0], (const fp **)&a_array[0], + &lda[0], (const fp **)&x_array[0], &incx[0], &c_array[0], + &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dgmm_batch, - &left_right[0], &m[0], &n[0], (const fp **)&a_array[0], &lda[0], - (const fp **)&x_array[0], &incx[0], &c_array[0], &ldc[0], - group_count, &group_size[0], dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dgmm_batch, + &left_right[0], &m[0], &n[0], (const fp **)&a_array[0], + &lda[0], (const fp **)&x_array[0], &incx[0], &c_array[0], + &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } @@ -292,6 +292,8 @@ TEST_P(DgmmBatchUsmTests, RealSinglePrecision) { } TEST_P(DgmmBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -301,13 +303,15 @@ TEST_P(DgmmBatchUsmTests, ComplexSinglePrecision) { } TEST_P(DgmmBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(DgmmBatchUsmTestSuite, DgmmBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp index 76e2bf2c4..5241cb822 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride.cpp @@ -47,13 +47,13 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Prepare data. int64_t m, n, k; int64_t lda, ldb, ldc; oneapi::mkl::transpose transa, transb; - fp alpha, beta; + Ts alpha, beta; int64_t i, tmp; batch_size = 1 + std::rand() % 20; @@ -63,14 +63,11 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { lda = std::max(m, k); ldb = std::max(n, k); ldc = std::max(m, n); - alpha = rand_scalar(); - beta = rand_scalar(); + alpha = rand_scalar(); + beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa = (oneapi::mkl::transpose)(std::rand() % 2); - transb = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -82,11 +79,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { else transb = (oneapi::mkl::transpose)tmp; } + else { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); + } int64_t stride_a, stride_b, stride_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * k : lda * m; stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * k; stride_c = ldc * n; @@ -99,8 +100,12 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } - vector> A(stride_a * batch_size), B(stride_b * batch_size); - vector> C(stride_c * batch_size), C_ref(stride_c * batch_size); + vector> A(stride_a * batch_size); + vector> B(stride_b * batch_size); + vector> C(stride_c * batch_size), + C_cast_ref(stride_c * batch_size); + vector> A_ref(stride_a * batch_size), B_ref(stride_b * batch_size), + C_ref(stride_c * batch_size); for (i = 0; i < batch_size; i++) { rand_matrix(A.data() + stride_a * i, layout, transa, m, k, lda); @@ -108,10 +113,18 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { rand_matrix(C.data() + stride_c * i, layout, oneapi::mkl::transpose::nontrans, m, n, ldc); } - C_ref = C; + for (size_t i = 0; i < A.size(); ++i) { + A_ref[i] = A[i]; + } + for (size_t i = 0; i < B.size(); ++i) { + B_ref[i] = B[i]; + } + for (size_t i = 0; i < C.size(); ++i) { + C_ref[i] = C[i]; + } // Call reference GEMM_BATCH_STRIDE. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int m_ref = (int)m; int n_ref = (int)n; int k_ref = (int)k; @@ -121,12 +134,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int batch_size_ref = (int)batch_size; for (i = 0; i < batch_size_ref; i++) { - ::gemm( - convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), - convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, - (const int *)&k_ref, (const fp_ref *)&alpha, (const fp_ref *)(A.data() + stride_a * i), - (const int *)&lda_ref, (const fp_ref *)(B.data() + stride_b * i), (const int *)&ldb_ref, - (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), (const int *)&ldc_ref); + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, + (const int *)&k_ref, (const fp_ref *)&alpha, + (const fp_ref *)(A_ref.data() + stride_a * i), (const int *)&lda_ref, + (const fp_ref *)(B_ref.data() + stride_b * i), (const int *)&ldb_ref, + (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), + (const int *)&ldc_ref); } // Call DPC++ GEMM_BATCH_STRIDE. @@ -147,14 +161,14 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { queue main_queue(*dev, exception_handler); - buffer A_buffer(A.data(), range<1>(A.size())); - buffer B_buffer(B.data(), range<1>(B.size())); - buffer C_buffer(C.data(), range<1>(C.size())); + buffer A_buffer(A.data(), range<1>(A.size())); + buffer B_buffer(B.data(), range<1>(B.size())); + buffer C_buffer(C.data(), range<1>(C.size())); try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gemm_batch( main_queue, transa, transb, m, n, k, alpha, A_buffer, lda, stride_a, B_buffer, ldb, stride_b, beta, C_buffer, ldc, stride_c, batch_size); @@ -168,19 +182,22 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, transa, - transb, m, n, k, alpha, A_buffer, lda, stride_a, B_buffer, ldb, - stride_b, beta, C_buffer, ldc, stride_c, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, + transa, transb, m, n, k, alpha, A_buffer, lda, stride_a, + B_buffer, ldb, stride_b, beta, C_buffer, ldc, stride_c, + batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, transa, - transb, m, n, k, alpha, A_buffer, lda, stride_a, B_buffer, ldb, - stride_b, beta, C_buffer, ldc, stride_c, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, + transa, transb, m, n, k, alpha, A_buffer, lda, stride_a, + B_buffer, ldb, stride_b, beta, C_buffer, ldc, stride_c, + batch_size); break; default: break; } #endif + main_queue.wait_and_throw(); } catch (exception const &e) { std::cout << "Caught synchronous SYCL exception during GEMM_BATCH_STRIDE:\n" @@ -198,11 +215,18 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. + int tol_scalar = 10; + int error_mag = tol_scalar * k; + if (std::is_same_v) + error_mag = 1; - auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = - check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::column_major, - stride_c * batch_size, 1, stride_c * batch_size, 10 * k, std::cout); + for (size_t i = 0; i < C_ref.size(); ++i) { + C_cast_ref[i] = C_ref[i]; + } + auto C_accessor = C_buffer.get_host_access(read_only); + bool good = check_almost_equal_matrix(C_accessor, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, + error_mag, std::cout); return (int)good; } @@ -211,30 +235,54 @@ class GemmBatchStrideTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchStrideTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, HalfHalfFloatPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, Int8Int8SinglePrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideTests, Int8Int8Int32Precision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, RealDoublePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchStrideTestSuite, GemmBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp index f75307350..97f2dd086 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_stride_usm.cpp @@ -47,7 +47,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -72,7 +72,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t m, n, k; int64_t lda, ldb, ldc; oneapi::mkl::transpose transa, transb; - fp alpha, beta; + Ts alpha, beta; int64_t i, tmp; @@ -83,13 +83,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { lda = std::max(m, k); ldb = std::max(n, k); ldc = std::max(m, n); - alpha = rand_scalar(); - beta = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa = (oneapi::mkl::transpose)(std::rand() % 2); - transb = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + alpha = rand_scalar(); + beta = rand_scalar(); + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa = oneapi::mkl::transpose::conjtrans; @@ -101,11 +98,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { else transb = (oneapi::mkl::transpose)tmp; } + else { + transa = (oneapi::mkl::transpose)(std::rand() % 2); + transb = (oneapi::mkl::transpose)(std::rand() % 2); + } int64_t stride_a, stride_b, stride_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * k : lda * m; stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * k; stride_c = ldc * n; @@ -118,18 +119,27 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { default: break; } - auto ua = usm_allocator(cxt, *dev); - vector A(ua), B(ua), C(ua), C_ref(ua); + auto ua = usm_allocator(cxt, *dev); + auto ub = usm_allocator(cxt, *dev); + auto uc = usm_allocator(cxt, *dev); + auto us = usm_allocator(cxt, *dev); + vector A(ua); + vector B(ub); + vector C(uc), C_cast_ref(uc); + vector A_ref(us), B_ref(us), C_ref(us); A.resize(stride_a * batch_size); B.resize(stride_b * batch_size); C.resize(stride_c * batch_size); + A_ref.resize(stride_c * batch_size); + B_ref.resize(stride_c * batch_size); C_ref.resize(stride_c * batch_size); + C_cast_ref.resize(stride_c * batch_size); - fp **a_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **b_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **c_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); - fp **c_ref_array = (fp **)oneapi::mkl::malloc_shared(64, sizeof(fp *) * batch_size, *dev, cxt); + Ta **a_array = (Ta **)oneapi::mkl::malloc_shared(64, sizeof(Ta *) * batch_size, *dev, cxt); + Tb **b_array = (Tb **)oneapi::mkl::malloc_shared(64, sizeof(Tb *) * batch_size, *dev, cxt); + Tc **c_array = (Tc **)oneapi::mkl::malloc_shared(64, sizeof(Tc *) * batch_size, *dev, cxt); + Ts **c_ref_array = (Ts **)oneapi::mkl::malloc_shared(64, sizeof(Ts *) * batch_size, *dev, cxt); if ((a_array == NULL) || (b_array == NULL) || (c_array == NULL) || (c_ref_array == NULL)) { std::cout << "Error cannot allocate arrays of pointers\n"; @@ -147,17 +157,21 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { c_ref_array[i] = &C_ref[i * stride_c]; } - rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_a * batch_size, 1, stride_a * batch_size); - rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_b * batch_size, 1, stride_b * batch_size); - rand_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size); - copy_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, + stride_a * batch_size, 1, stride_a * batch_size, A_ref); + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, + stride_b * batch_size, 1, stride_b * batch_size, B_ref); + copy_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size, C_ref); // Call reference GEMM_BATCH_STRIDE. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int m_ref = (int)m; int n_ref = (int)n; int k_ref = (int)k; @@ -166,12 +180,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int ldc_ref = (int)ldc; int batch_size_ref = (int)batch_size; for (i = 0; i < batch_size_ref; i++) { - ::gemm( - convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), - convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, - (const int *)&k_ref, (const fp_ref *)&alpha, (const fp_ref *)(A.data() + stride_a * i), - (const int *)&lda_ref, (const fp_ref *)(B.data() + stride_b * i), (const int *)&ldb_ref, - (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), (const int *)&ldc_ref); + ::gemm(convert_to_cblas_layout(layout), convert_to_cblas_trans(transa), + convert_to_cblas_trans(transb), (const int *)&m_ref, (const int *)&n_ref, + (const int *)&k_ref, (const fp_ref *)&alpha, + (const fp_ref *)(A_ref.data() + stride_a * i), (const int *)&lda_ref, + (const fp_ref *)(B_ref.data() + stride_b * i), (const int *)&ldb_ref, + (const fp_ref *)&beta, (fp_ref *)(C_ref.data() + stride_c * i), + (const int *)&ldc_ref); } // Call DPC++ GEMM_BATCH_STRIDE. @@ -179,7 +194,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemm_batch( main_queue, transa, transb, m, n, k, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, beta, &C[0], ldc, stride_c, batch_size, dependencies); @@ -191,22 +206,24 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { break; default: break; } - done.wait(); + done.wait_and_throw(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, transa, - transb, m, n, k, alpha, &A[0], lda, stride_a, &B[0], ldb, - stride_b, beta, &C[0], ldc, stride_c, batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, + transa, transb, m, n, k, alpha, &A[0], lda, stride_a, &B[0], + ldb, stride_b, beta, &C[0], ldc, stride_c, batch_size, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, transa, - transb, m, n, k, alpha, &A[0], lda, stride_a, &B[0], ldb, - stride_b, beta, &C[0], ldc, stride_c, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, + transa, transb, m, n, k, alpha, &A[0], lda, stride_a, &B[0], + ldb, stride_b, beta, &C[0], ldc, stride_c, batch_size, + dependencies); break; default: break; } - main_queue.wait(); + main_queue.wait_and_throw(); #endif } catch (exception const &e) { @@ -229,9 +246,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = - check_equal_matrix(C, C_ref, oneapi::mkl::layout::column_major, stride_c * batch_size, 1, - stride_c * batch_size, 10 * k, std::cout); + int tol_scalar = 10; + int error_mag = tol_scalar * k; + if (std::is_same_v) + error_mag = 1; + + for (size_t i = 0; i < C_ref.size(); ++i) { + C_cast_ref[i] = C_ref[i]; + } + bool good = check_almost_equal_matrix(C, C_cast_ref, oneapi::mkl::layout::col_major, + stride_c * batch_size, 1, stride_c * batch_size, + error_mag, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); @@ -245,30 +270,54 @@ class GemmBatchStrideUsmTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchStrideUsmTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, HalfHalfFloatPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, Int8Int8SinglePrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchStrideUsmTests, Int8Int8Int32Precision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, RealDoublePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchStrideUsmTestSuite, GemmBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp index ed17c3134..a651f9ae3 100644 --- a/tests/unit_tests/blas/batch/gemm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemm_batch_usm.cpp @@ -47,7 +47,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -76,8 +76,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { auto uatranspose = usm_allocator(cxt, *dev); vector transa(uatranspose), transb(uatranspose); - auto uafp = usm_allocator(cxt, *dev); - vector alpha(uafp), beta(uafp); + auto uaTs = usm_allocator(cxt, *dev); + vector alpha(uaTs), beta(uaTs); m.resize(group_count); n.resize(group_count); @@ -104,13 +104,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { lda[i] = std::max(m[i], k[i]); ldb[i] = std::max(n[i], k[i]); ldc[i] = std::max(m[i], n[i]); - alpha[i] = rand_scalar(); - beta[i] = rand_scalar(); - if ((std::is_same::value) || (std::is_same::value)) { - transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); - transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); - } - else { + alpha[i] = rand_scalar(); + beta[i] = rand_scalar(); + if ((std::is_same>::value) || + (std::is_same>::value)) { tmp = std::rand() % 3; if (tmp == 2) transa[i] = oneapi::mkl::transpose::conjtrans; @@ -122,21 +119,33 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { else transb[i] = (oneapi::mkl::transpose)tmp; } + else { + transa[i] = (oneapi::mkl::transpose)(std::rand() % 2); + transb[i] = (oneapi::mkl::transpose)(std::rand() % 2); + } total_batch_count += group_size[i]; } - auto uafpp = usm_allocator(cxt, *dev); - vector a_array(uafpp), b_array(uafpp), c_array(uafpp), - c_ref_array(uafpp); + auto uaTap = usm_allocator(cxt, *dev); + auto uaTbp = usm_allocator(cxt, *dev); + auto uaTcp = usm_allocator(cxt, *dev); + auto uaTsp = usm_allocator(cxt, *dev); + vector a_array(uaTap); + vector b_array(uaTbp); + vector c_array(uaTcp), c_cast_ref_array(uaTcp); + vector a_ref_array(uaTsp), b_ref_array(uaTsp), c_ref_array(uaTsp); a_array.resize(total_batch_count); b_array.resize(total_batch_count); c_array.resize(total_batch_count); + a_ref_array.resize(total_batch_count); + b_ref_array.resize(total_batch_count); + c_cast_ref_array.resize(total_batch_count); c_ref_array.resize(total_batch_count); idx = 0; for (i = 0; i < group_count; i++) { switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda[i] * ((transa[i] == oneapi::mkl::transpose::nontrans) ? k[i] : m[i]); size_b = ldb[i] * ((transb[i] == oneapi::mkl::transpose::nontrans) ? n[i] : k[i]); size_c = ldc[i] * n[i]; @@ -149,13 +158,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { default: break; } for (j = 0; j < group_size[i]; j++) { - a_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_a, *dev, cxt); - b_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_b, *dev, cxt); - c_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_c, *dev, cxt); - c_ref_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_c, *dev, cxt); + a_array[idx] = (Ta *)oneapi::mkl::malloc_shared(64, sizeof(Ta) * size_a, *dev, cxt); + b_array[idx] = (Tb *)oneapi::mkl::malloc_shared(64, sizeof(Tb) * size_b, *dev, cxt); + c_array[idx] = (Tc *)oneapi::mkl::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + a_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_a, *dev, cxt); + b_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_b, *dev, cxt); + c_cast_ref_array[idx] = + (Tc *)oneapi::mkl::malloc_shared(64, sizeof(Tc) * size_c, *dev, cxt); + c_ref_array[idx] = (Ts *)oneapi::mkl::malloc_shared(64, sizeof(Ts) * size_c, *dev, cxt); rand_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i]); rand_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i]); rand_matrix(c_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i]); + copy_matrix(a_array[idx], layout, transa[i], m[i], k[i], lda[i], a_ref_array[idx]); + copy_matrix(b_array[idx], layout, transb[i], k[i], n[i], ldb[i], b_ref_array[idx]); copy_matrix(c_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], ldc[i], c_ref_array[idx]); idx++; @@ -163,7 +178,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } // Call reference GEMM_BATCH. - using fp_ref = typename ref_type_info::type; + using fp_ref = typename ref_type_info::type; int *m_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); int *n_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); int *k_ref = (int *)oneapi::mkl::aligned_alloc(64, sizeof(int) * group_count); @@ -196,6 +211,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -216,9 +234,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (j = 0; j < group_size_ref[i]; j++) { ::gemm(convert_to_cblas_layout(layout), transa_ref[i], transb_ref[i], (const int *)&m_ref[i], (const int *)&n_ref[i], (const int *)&k_ref[i], - (const fp_ref *)&alpha[i], (const fp_ref *)a_array[idx], - (const int *)&lda_ref[i], (const fp_ref *)b_array[idx], (const int *)&ldb_ref[i], - (const fp_ref *)&beta[i], (fp_ref *)c_ref_array[idx], (const int *)&ldc_ref[i]); + (const fp_ref *)&alpha[i], (const fp_ref *)a_ref_array[idx], + (const int *)&lda_ref[i], (const fp_ref *)b_ref_array[idx], + (const int *)&ldb_ref[i], (const fp_ref *)&beta[i], (fp_ref *)c_ref_array[idx], + (const int *)&ldc_ref[i]); idx++; } } @@ -228,40 +247,40 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemm_batch( main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], &ldb[0], &beta[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: done = oneapi::mkl::blas::row_major::gemm_batch( main_queue, &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], &ldb[0], &beta[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } - done.wait(); + done.wait_and_throw(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, - &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], - &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, - &group_size[0], dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_batch, + &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], + (const Ta **)&a_array[0], &lda[0], (const Tb **)&b_array[0], + &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, + &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, &transa[0], - &transb[0], &m[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], (const fp **)&b_array[0], - &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, - &group_size[0], dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_batch, + &transa[0], &transb[0], &m[0], &n[0], &k[0], &alpha[0], + (const Ta **)&a_array[0], &lda[0], (const Ta **)&b_array[0], + &ldb[0], &beta[0], &c_array[0], &ldc[0], group_count, + &group_size[0], dependencies); break; default: break; } - main_queue.wait(); + main_queue.wait_and_throw(); #endif } catch (exception const &e) { @@ -286,6 +305,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -299,11 +321,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { bool good = true; // Compare the results of reference implementation and DPC++ implementation. + int tol_scalar = 10; + idx = 0; for (i = 0; i < group_count; i++) { for (j = 0; j < group_size[i]; j++) { - good = good && check_equal_matrix(c_array[idx], c_ref_array[idx], layout, m[i], n[i], - ldc[i], 10 * k[i], std::cout); + int error_mag = tol_scalar * k[i]; + if (std::is_same_v) + error_mag = 1; + + copy_matrix(c_ref_array[idx], layout, oneapi::mkl::transpose::nontrans, m[i], n[i], + ldc[i], c_cast_ref_array[idx]); + good = good && check_almost_equal_matrix(c_array[idx], c_cast_ref_array[idx], layout, + m[i], n[i], ldc[i], error_mag, std::cout); idx++; } } @@ -322,6 +352,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { oneapi::mkl::free_shared(a_array[idx], cxt); oneapi::mkl::free_shared(b_array[idx], cxt); oneapi::mkl::free_shared(c_array[idx], cxt); + oneapi::mkl::free_shared(a_ref_array[idx], cxt); + oneapi::mkl::free_shared(b_ref_array[idx], cxt); + oneapi::mkl::free_shared(c_cast_ref_array[idx], cxt); oneapi::mkl::free_shared(c_ref_array[idx], cxt); idx++; } @@ -334,30 +367,54 @@ class GemmBatchUsmTests : public ::testing::TestWithParam> {}; TEST_P(GemmBatchUsmTests, RealHalfPrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, HalfHalfFloatPrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, Int8Int8SinglePrecision) { + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 5))); +} + +TEST_P(GemmBatchUsmTests, Int8Int8Int32Precision) { + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, RealSinglePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, RealDoublePrecision) { - EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, std::complex>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } TEST_P(GemmBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( - test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); + (test, std::complex, std::complex, + std::complex>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5))); } INSTANTIATE_TEST_SUITE_P(GemmBatchUsmTestSuite, GemmBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/gemv_batch_stride.cpp b/tests/unit_tests/blas/batch/gemv_batch_stride.cpp index 25ec8554e..bd92f70ca 100644 --- a/tests/unit_tests/blas/batch/gemv_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/gemv_batch_stride.cpp @@ -136,7 +136,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gemv_batch( main_queue, transa, m, n, alpha, A_buffer, lda, stride_a, x_buffer, incx, stride_x, beta, y_buffer, incy, stride_y, batch_size); @@ -150,15 +150,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv_batch, transa, - m, n, alpha, A_buffer, lda, stride_a, x_buffer, incx, stride_x, - beta, y_buffer, incy, stride_y, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv_batch, + transa, m, n, alpha, A_buffer, lda, stride_a, x_buffer, + incx, stride_x, beta, y_buffer, incy, stride_y, batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv_batch, transa, m, - n, alpha, A_buffer, lda, stride_a, x_buffer, incx, stride_x, - beta, y_buffer, incy, stride_y, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv_batch, + transa, m, n, alpha, A_buffer, lda, stride_a, x_buffer, + incx, stride_x, beta, y_buffer, incy, stride_y, batch_size); break; default: break; } @@ -181,7 +181,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = true; for (i = 0; i < batch_size; i++) { good = good && check_equal_vector(y_accessor.get_pointer() + i * stride_y, @@ -201,6 +201,8 @@ TEST_P(GemvBatchStrideTests, RealSinglePrecision) { } TEST_P(GemvBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 5)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), -2, -3, 5)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1, 1, 5)); @@ -214,6 +216,8 @@ TEST_P(GemvBatchStrideTests, ComplexSinglePrecision) { } TEST_P(GemvBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 5); test>(std::get<0>(GetParam()), std::get<1>(GetParam()), -2, -3, 5); @@ -222,7 +226,7 @@ TEST_P(GemvBatchStrideTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemvBatchStrideTestSuite, GemvBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/gemv_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/gemv_batch_stride_usm.cpp index 232868405..d6eb47887 100644 --- a/tests/unit_tests/blas/batch/gemv_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/gemv_batch_stride_usm.cpp @@ -139,7 +139,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemv_batch( main_queue, transa, m, n, alpha, &A[0], lda, stride_a, &x[0], incx, stride_x, beta, &y[0], incy, stride_y, batch_size, dependencies); @@ -154,15 +154,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t incx, int64_t incy, in done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv_batch, transa, - m, n, alpha, &A[0], lda, stride_a, &x[0], incx, stride_x, beta, - &y[0], incy, stride_y, batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv_batch, + transa, m, n, alpha, &A[0], lda, stride_a, &x[0], incx, + stride_x, beta, &y[0], incy, stride_y, batch_size, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv_batch, transa, m, - n, alpha, &A[0], lda, stride_a, &x[0], incx, stride_x, beta, - &y[0], incy, stride_y, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv_batch, + transa, m, n, alpha, &A[0], lda, stride_a, &x[0], incx, + stride_x, beta, &y[0], incy, stride_y, batch_size, + dependencies); break; default: break; } @@ -204,6 +206,8 @@ TEST_P(GemvBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(GemvBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 5)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), -2, -3, 5)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1, 1, 5)); @@ -217,6 +221,8 @@ TEST_P(GemvBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(GemvBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 2, 3, 5); test>(std::get<0>(GetParam()), std::get<1>(GetParam()), -2, -3, 5); @@ -225,7 +231,7 @@ TEST_P(GemvBatchStrideUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemvBatchStrideUsmTestSuite, GemvBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/gemv_batch_usm.cpp b/tests/unit_tests/blas/batch/gemv_batch_usm.cpp index 906ed9534..4ad661f5b 100644 --- a/tests/unit_tests/blas/batch/gemv_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/gemv_batch_usm.cpp @@ -129,7 +129,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { - size_a = (layout == oneapi::mkl::layout::column_major) ? lda[i] * n[i] : lda[i] * m[i]; + size_a = (layout == oneapi::mkl::layout::col_major) ? lda[i] * n[i] : lda[i] * m[i]; x_len = (transa[i] == oneapi::mkl::transpose::nontrans) ? n[i] : m[i]; y_len = (transa[i] == oneapi::mkl::transpose::nontrans) ? m[i] : n[i]; size_x = 1 + (x_len - 1) * std::abs(incx[i]); @@ -205,7 +205,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemv_batch( main_queue, &transa[0], &m[0], &n[0], &alpha[0], (const fp **)&a_array[0], &lda[0], (const fp **)&x_array[0], &incx[0], &beta[0], &y_array[0], &incy[0], @@ -222,18 +222,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv_batch, - &transa[0], &m[0], &n[0], &alpha[0], (const fp **)&a_array[0], - &lda[0], (const fp **)&x_array[0], &incx[0], &beta[0], - &y_array[0], &incy[0], group_count, &group_size[0], - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv_batch, + &transa[0], &m[0], &n[0], &alpha[0], + (const fp **)&a_array[0], &lda[0], (const fp **)&x_array[0], + &incx[0], &beta[0], &y_array[0], &incy[0], group_count, + &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv_batch, &transa[0], - &m[0], &n[0], &alpha[0], (const fp **)&a_array[0], &lda[0], - (const fp **)&x_array[0], &incx[0], &beta[0], &y_array[0], - &incy[0], group_count, &group_size[0], dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv_batch, + &transa[0], &m[0], &n[0], &alpha[0], + (const fp **)&a_array[0], &lda[0], (const fp **)&x_array[0], + &incx[0], &beta[0], &y_array[0], &incy[0], group_count, + &group_size[0], dependencies); break; default: break; } @@ -312,6 +313,8 @@ TEST_P(GemvBatchUsmTests, RealSinglePrecision) { } TEST_P(GemvBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -321,13 +324,15 @@ TEST_P(GemvBatchUsmTests, ComplexSinglePrecision) { } TEST_P(GemvBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(GemvBatchUsmTestSuite, GemvBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp b/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp index 079abc083..ac8bbb2b4 100644 --- a/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/imatcopy_batch_stride.cpp @@ -66,7 +66,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_b, stride; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = lda * n; stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; stride = std::max(stride_a, stride_b); @@ -81,9 +81,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { vector> AB(stride * batch_size), AB_ref(stride * batch_size); - rand_matrix(AB.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(AB.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride * batch_size, 1, stride * batch_size); - copy_matrix(AB.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(AB.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride * batch_size, 1, stride * batch_size, AB_ref.data()); // Call reference IMATCOPY_BATCH_STRIDE. @@ -120,7 +120,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::imatcopy_batch( main_queue, trans, m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); break; @@ -132,13 +132,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, - trans, m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, + trans, m, n, alpha, AB_buffer, lda, ldb, stride, + batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, trans, - m, n, alpha, AB_buffer, lda, ldb, stride, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, + trans, m, n, alpha, AB_buffer, lda, ldb, stride, + batch_size); break; default: break; } @@ -161,8 +163,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Compare the results of reference implementation and DPC++ implementation. - auto AB_accessor = AB_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(AB_accessor, AB_ref, oneapi::mkl::layout::column_major, + auto AB_accessor = AB_buffer.get_host_access(read_only); + bool good = check_equal_matrix(AB_accessor, AB_ref, oneapi::mkl::layout::col_major, stride * batch_size, 1, stride * batch_size, 10, std::cout); return (int)good; @@ -176,6 +178,8 @@ TEST_P(ImatcopyBatchStrideTests, RealSinglePrecision) { } TEST_P(ImatcopyBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -185,13 +189,15 @@ TEST_P(ImatcopyBatchStrideTests, ComplexSinglePrecision) { } TEST_P(ImatcopyBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(ImatcopyBatchStrideTestSuite, ImatcopyBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp index 3a5a85aec..b3099d309 100644 --- a/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/imatcopy_batch_stride_usm.cpp @@ -85,7 +85,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_b, stride; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = lda * n; stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; stride = std::max(stride_a, stride_b); @@ -117,9 +117,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { ab_ref_array[i] = &AB_ref[i * stride]; } - rand_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(AB, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride * batch_size, 1, stride * batch_size); - copy_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(AB, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride * batch_size, 1, stride * batch_size, AB_ref); // Call reference IMATCOPY_BATCH_STRIDE. @@ -136,7 +136,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::imatcopy_batch( main_queue, trans, m, n, alpha, &AB[0], lda, ldb, stride, batch_size, dependencies); @@ -151,14 +151,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, - trans, m, n, alpha, &AB[0], lda, ldb, stride, batch_size, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, + trans, m, n, alpha, &AB[0], lda, ldb, stride, batch_size, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, trans, - m, n, alpha, &AB[0], lda, ldb, stride, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, + trans, m, n, alpha, &AB[0], lda, ldb, stride, batch_size, + dependencies); break; default: break; } @@ -183,8 +184,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(AB, AB_ref, oneapi::mkl::layout::column_major, - stride * batch_size, 1, stride * batch_size, 10, std::cout); + bool good = check_equal_matrix(AB, AB_ref, oneapi::mkl::layout::col_major, stride * batch_size, + 1, stride * batch_size, 10, std::cout); oneapi::mkl::free_shared(ab_array, cxt); oneapi::mkl::free_shared(ab_ref_array, cxt); @@ -200,6 +201,8 @@ TEST_P(ImatcopyBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(ImatcopyBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -209,13 +212,15 @@ TEST_P(ImatcopyBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(ImatcopyBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(ImatcopyBatchStrideUsmTestSuite, ImatcopyBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/imatcopy_batch_usm.cpp b/tests/unit_tests/blas/batch/imatcopy_batch_usm.cpp index 967b0f9de..74c9881af 100644 --- a/tests/unit_tests/blas/batch/imatcopy_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/imatcopy_batch_usm.cpp @@ -112,7 +112,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda[i] * n[i]; size_b = (trans[i] == oneapi::mkl::transpose::nontrans) ? ldb[i] * n[i] : ldb[i] * m[i]; @@ -128,9 +128,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (j = 0; j < group_size[i]; j++) { ab_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size, *dev, cxt); ab_ref_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size, *dev, cxt); - rand_matrix(ab_array[idx], oneapi::mkl::layout::column_major, + rand_matrix(ab_array[idx], oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size, 1, size); - copy_matrix(ab_array[idx], oneapi::mkl::layout::column_major, + copy_matrix(ab_array[idx], oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size, 1, size, ab_ref_array[idx]); idx++; } @@ -155,7 +155,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::imatcopy_batch( main_queue, trans.data(), m.data(), n.data(), alpha.data(), ab_array.data(), lda.data(), ldb.data(), group_count, group_size.data(), dependencies); @@ -170,17 +170,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, - trans.data(), m.data(), n.data(), alpha.data(), ab_array.data(), - lda.data(), ldb.data(), group_count, group_size.data(), - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy_batch, + trans.data(), m.data(), n.data(), alpha.data(), + ab_array.data(), lda.data(), ldb.data(), group_count, + group_size.data(), dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, - trans.data(), m.data(), n.data(), alpha.data(), ab_array.data(), - lda.data(), ldb.data(), group_count, group_size.data(), - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy_batch, + trans.data(), m.data(), n.data(), alpha.data(), + ab_array.data(), lda.data(), ldb.data(), group_count, + group_size.data(), dependencies); break; default: break; } @@ -215,7 +215,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda[i] * n[i]; size_b = (trans[i] == oneapi::mkl::transpose::nontrans) ? ldb[i] * n[i] : ldb[i] * m[i]; @@ -229,9 +229,9 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } size = std::max(size_a, size_b); for (j = 0; j < group_size[i]; j++) { - good = good && check_equal_matrix(ab_array[idx], ab_ref_array[idx], - oneapi::mkl::layout::column_major, size, 1, size, 10, - std::cout); + good = good && + check_equal_matrix(ab_array[idx], ab_ref_array[idx], + oneapi::mkl::layout::col_major, size, 1, size, 10, std::cout); idx++; } } @@ -256,6 +256,8 @@ TEST_P(ImatcopyBatchUsmTests, RealSinglePrecision) { } TEST_P(ImatcopyBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -265,13 +267,15 @@ TEST_P(ImatcopyBatchUsmTests, ComplexSinglePrecision) { } TEST_P(ImatcopyBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(ImatcopyBatchUsmTestSuite, ImatcopyBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp b/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp index 222d918c6..cc20d0e3b 100644 --- a/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/omatadd_batch_stride.cpp @@ -70,7 +70,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_b, stride_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * n : lda * m; stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; stride_c = ldc * n; @@ -86,13 +86,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { vector> A(stride_a * batch_size), B(stride_b * batch_size), C(stride_c * batch_size), C_ref(stride_c * batch_size); - rand_matrix(A.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(A.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_a * batch_size, 1, stride_a * batch_size); - rand_matrix(B.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(B.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_b * batch_size, 1, stride_b * batch_size); - rand_matrix(C.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(C.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size); - copy_matrix(C.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(C.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size, C_ref.data()); // Call reference OMATADD_BATCH_STRIDE. @@ -132,7 +132,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::omatadd_batch( main_queue, transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, batch_size); @@ -146,15 +146,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd_batch, - transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, - B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd_batch, + transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, + B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, + batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd_batch, transa, - transb, m, n, alpha, A_buffer, lda, stride_a, beta, B_buffer, - ldb, stride_b, C_buffer, ldc, stride_c, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd_batch, + transa, transb, m, n, alpha, A_buffer, lda, stride_a, beta, + B_buffer, ldb, stride_b, C_buffer, ldc, stride_c, + batch_size); break; default: break; } @@ -177,8 +179,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::column_major, + auto C_accessor = C_buffer.get_host_access(read_only); + bool good = check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, 1, stride_c * batch_size, 10, std::cout); return (int)good; @@ -192,6 +194,8 @@ TEST_P(OmataddBatchStrideTests, RealSinglePrecision) { } TEST_P(OmataddBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -201,13 +205,15 @@ TEST_P(OmataddBatchStrideTests, ComplexSinglePrecision) { } TEST_P(OmataddBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(OmataddBatchStrideTestSuite, OmataddBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp index a494cef48..7388084cb 100644 --- a/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/omatadd_batch_stride_usm.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_b, stride_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * n : lda * m; stride_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; stride_c = ldc * n; @@ -131,13 +131,13 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { c_ref_array[i] = &C_ref[i * stride_c]; } - rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_a * batch_size, 1, stride_a * batch_size); - rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_b * batch_size, 1, stride_b * batch_size); - rand_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size); - copy_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size, C_ref); // Call reference OMATADD_BATCH_STRIDE. @@ -156,7 +156,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::omatadd_batch( main_queue, transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], ldb, stride_b, &C[0], ldc, stride_c, batch_size, dependencies); @@ -171,15 +171,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd_batch, - transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], - ldb, stride_b, &C[0], ldc, stride_c, batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd_batch, + transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, + &B[0], ldb, stride_b, &C[0], ldc, stride_c, batch_size, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd_batch, transa, - transb, m, n, alpha, &A[0], lda, stride_a, beta, &B[0], ldb, - stride_b, &C[0], ldc, stride_c, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd_batch, + transa, transb, m, n, alpha, &A[0], lda, stride_a, beta, + &B[0], ldb, stride_b, &C[0], ldc, stride_c, batch_size, + dependencies); break; default: break; } @@ -206,8 +208,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::column_major, - stride_c * batch_size, 1, stride_c * batch_size, 10, std::cout); + bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, + 1, stride_c * batch_size, 10, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); @@ -225,6 +227,8 @@ TEST_P(OmataddBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(OmataddBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -234,13 +238,15 @@ TEST_P(OmataddBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(OmataddBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(OmataddBatchStrideUsmTestSuite, OmataddBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp b/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp index e07306023..d08329fc6 100644 --- a/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/omatcopy_batch_stride.cpp @@ -67,7 +67,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_b; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = lda * n; stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; break; @@ -121,7 +121,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::omatcopy_batch(main_queue, trans, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, stride_b, batch_size); @@ -135,15 +135,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, - trans, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, - stride_b, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, + trans, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, + stride_b, batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, trans, - m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, stride_b, - batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, + trans, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, + stride_b, batch_size); break; default: break; } @@ -166,8 +166,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Compare the results of reference implementation and DPC++ implementation. - auto B_accessor = B_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(B_accessor, B_ref, oneapi::mkl::layout::column_major, + auto B_accessor = B_buffer.get_host_access(read_only); + bool good = check_equal_matrix(B_accessor, B_ref, oneapi::mkl::layout::col_major, stride_b * batch_size, 1, stride_b * batch_size, 10, std::cout); return (int)good; @@ -181,6 +181,8 @@ TEST_P(OmatcopyBatchStrideTests, RealSinglePrecision) { } TEST_P(OmatcopyBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -190,13 +192,15 @@ TEST_P(OmatcopyBatchStrideTests, ComplexSinglePrecision) { } TEST_P(OmatcopyBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(OmatcopyBatchStrideTestSuite, OmatcopyBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp index 4322d4e3e..7479b57db 100644 --- a/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/omatcopy_batch_stride_usm.cpp @@ -87,7 +87,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_b; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = lda * n; stride_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; break; @@ -123,11 +123,11 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { b_ref_array[i] = &B_ref[i * stride_b]; } - rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_a * batch_size, 1, stride_a * batch_size); - rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_b * batch_size, 1, stride_b * batch_size); - copy_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_b * batch_size, 1, stride_b * batch_size, B_ref); // Call reference OMATCOPY_BATCH_STRIDE. @@ -145,7 +145,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::omatcopy_batch( main_queue, trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, batch_size, dependencies); @@ -160,15 +160,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, - trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, - batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, + trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, + stride_b, batch_size, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, trans, - m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, - batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, + trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, + stride_b, batch_size, dependencies); break; default: break; } @@ -194,8 +194,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(B, B_ref, oneapi::mkl::layout::column_major, - stride_b * batch_size, 1, stride_b * batch_size, 10, std::cout); + bool good = check_equal_matrix(B, B_ref, oneapi::mkl::layout::col_major, stride_b * batch_size, + 1, stride_b * batch_size, 10, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(b_array, cxt); @@ -212,6 +212,8 @@ TEST_P(OmatcopyBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(OmatcopyBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -221,13 +223,15 @@ TEST_P(OmatcopyBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(OmatcopyBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(OmatcopyBatchStrideUsmTestSuite, OmatcopyBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/omatcopy_batch_usm.cpp b/tests/unit_tests/blas/batch/omatcopy_batch_usm.cpp index 0516d0a97..7f1e4a103 100644 --- a/tests/unit_tests/blas/batch/omatcopy_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/omatcopy_batch_usm.cpp @@ -113,7 +113,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda[i] * n[i]; size_b = (trans[i] == oneapi::mkl::transpose::nontrans) ? ldb[i] * n[i] : ldb[i] * m[i]; @@ -129,11 +129,11 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { a_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_a, *dev, cxt); b_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_b, *dev, cxt); b_ref_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_b, *dev, cxt); - rand_matrix(a_array[idx], oneapi::mkl::layout::column_major, + rand_matrix(a_array[idx], oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_a, 1, size_a); - rand_matrix(b_array[idx], oneapi::mkl::layout::column_major, + rand_matrix(b_array[idx], oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, size_b); - copy_matrix(b_array[idx], oneapi::mkl::layout::column_major, + copy_matrix(b_array[idx], oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, size_b, b_ref_array[idx]); idx++; } @@ -158,7 +158,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::omatcopy_batch( main_queue, trans.data(), m.data(), n.data(), alpha.data(), (const fp **)a_array.data(), lda.data(), b_array.data(), ldb.data(), @@ -175,17 +175,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, - trans.data(), m.data(), n.data(), alpha.data(), - (const fp **)a_array.data(), lda.data(), b_array.data(), - ldb.data(), group_count, group_size.data(), dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy_batch, + trans.data(), m.data(), n.data(), alpha.data(), + (const fp **)a_array.data(), lda.data(), b_array.data(), + ldb.data(), group_count, group_size.data(), dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, - trans.data(), m.data(), n.data(), alpha.data(), - (const fp **)a_array.data(), lda.data(), b_array.data(), - ldb.data(), group_count, group_size.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy_batch, + trans.data(), m.data(), n.data(), alpha.data(), + (const fp **)a_array.data(), lda.data(), b_array.data(), + ldb.data(), group_count, group_size.data(), dependencies); break; default: break; } @@ -221,7 +221,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda[i] * n[i]; size_b = (trans[i] == oneapi::mkl::transpose::nontrans) ? ldb[i] * n[i] : ldb[i] * m[i]; @@ -235,8 +235,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { } for (j = 0; j < group_size[i]; j++) { good = good && check_equal_matrix(b_array[idx], b_ref_array[idx], - oneapi::mkl::layout::column_major, size_b, 1, size_b, - 10, std::cout); + oneapi::mkl::layout::col_major, size_b, 1, size_b, 10, + std::cout); idx++; } } @@ -262,6 +262,8 @@ TEST_P(OmatcopyBatchUsmTests, RealSinglePrecision) { } TEST_P(OmatcopyBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -271,13 +273,15 @@ TEST_P(OmatcopyBatchUsmTests, ComplexSinglePrecision) { } TEST_P(OmatcopyBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(OmatcopyBatchUsmTestSuite, OmatcopyBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/syrk_batch_stride.cpp b/tests/unit_tests/blas/batch/syrk_batch_stride.cpp index 94d532f6a..58dc4d7dc 100644 --- a/tests/unit_tests/blas/batch/syrk_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/syrk_batch_stride.cpp @@ -79,7 +79,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = (trans == oneapi::mkl::transpose::nontrans) ? lda * k : lda * n; stride_c = ldc * n; break; @@ -140,7 +140,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::syrk_batch(main_queue, upper_lower, trans, n, k, alpha, A_buffer, lda, stride_a, beta, C_buffer, ldc, stride_c, batch_size); @@ -154,15 +154,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk_batch, - upper_lower, trans, n, k, alpha, A_buffer, lda, stride_a, beta, - C_buffer, ldc, stride_c, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk_batch, + upper_lower, trans, n, k, alpha, A_buffer, lda, stride_a, + beta, C_buffer, ldc, stride_c, batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk_batch, - upper_lower, trans, n, k, alpha, A_buffer, lda, stride_a, beta, - C_buffer, ldc, stride_c, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk_batch, + upper_lower, trans, n, k, alpha, A_buffer, lda, stride_a, + beta, C_buffer, ldc, stride_c, batch_size); break; default: break; } @@ -185,10 +185,10 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = - check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::column_major, - stride_c * batch_size, 1, stride_c * batch_size, 10 * k, std::cout); + check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, + 1, stride_c * batch_size, 10 * k, std::cout); return (int)good; } @@ -201,6 +201,8 @@ TEST_P(SyrkBatchStrideTests, RealSinglePrecision) { } TEST_P(SyrkBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -210,13 +212,15 @@ TEST_P(SyrkBatchStrideTests, ComplexSinglePrecision) { } TEST_P(SyrkBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(SyrkBatchStrideTestSuite, SyrkBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/syrk_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/syrk_batch_stride_usm.cpp index dbfaca974..31aa09b79 100644 --- a/tests/unit_tests/blas/batch/syrk_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/syrk_batch_stride_usm.cpp @@ -98,7 +98,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { int64_t stride_a, stride_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: stride_a = (trans == oneapi::mkl::transpose::nontrans) ? lda * k : lda * n; stride_c = ldc * n; break; @@ -134,11 +134,11 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { c_ref_array[i] = &C_ref[i * stride_c]; } - rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_a * batch_size, 1, stride_a * batch_size); - rand_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + rand_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size); - copy_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, + copy_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, stride_c * batch_size, 1, stride_c * batch_size, C_ref); // Call reference SYRK_BATCH_STRIDE. @@ -161,7 +161,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::syrk_batch( main_queue, upper_lower, trans, n, k, alpha, &A[0], lda, stride_a, beta, &C[0], ldc, stride_c, batch_size, dependencies); @@ -176,15 +176,15 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk_batch, - upper_lower, trans, n, k, alpha, &A[0], lda, stride_a, beta, - &C[0], ldc, stride_c, batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk_batch, + upper_lower, trans, n, k, alpha, &A[0], lda, stride_a, beta, + &C[0], ldc, stride_c, batch_size, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk_batch, - upper_lower, trans, n, k, alpha, &A[0], lda, stride_a, beta, - &C[0], ldc, stride_c, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk_batch, + upper_lower, trans, n, k, alpha, &A[0], lda, stride_a, beta, + &C[0], ldc, stride_c, batch_size, dependencies); break; default: break; } @@ -210,9 +210,8 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t batch_size) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = - check_equal_matrix(C, C_ref, oneapi::mkl::layout::column_major, stride_c * batch_size, 1, - stride_c * batch_size, 10 * k, std::cout); + bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::col_major, stride_c * batch_size, + 1, stride_c * batch_size, 10 * k, std::cout); oneapi::mkl::free_shared(a_array, cxt); oneapi::mkl::free_shared(c_array, cxt); @@ -229,6 +228,8 @@ TEST_P(SyrkBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(SyrkBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -238,13 +239,15 @@ TEST_P(SyrkBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(SyrkBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(SyrkBatchStrideUsmTestSuite, SyrkBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/syrk_batch_usm.cpp b/tests/unit_tests/blas/batch/syrk_batch_usm.cpp index e63e352c1..36d0d6dd5 100644 --- a/tests/unit_tests/blas/batch/syrk_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/syrk_batch_usm.cpp @@ -127,7 +127,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { idx = 0; for (i = 0; i < group_count; i++) { switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda[i] * ((trans[i] == oneapi::mkl::transpose::nontrans) ? k[i] : n[i]); size_c = ldc[i] * n[i]; break; @@ -206,7 +206,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::syrk_batch( main_queue, &upper_lower[0], &trans[0], &n[0], &k[0], &alpha[0], (const fp **)&a_array[0], &lda[0], &beta[0], &c_array[0], &ldc[0], group_count, @@ -223,17 +223,17 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk_batch, - &upper_lower[0], &trans[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], &beta[0], &c_array[0], - &ldc[0], group_count, &group_size[0], dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk_batch, + &upper_lower[0], &trans[0], &n[0], &k[0], &alpha[0], + (const fp **)&a_array[0], &lda[0], &beta[0], &c_array[0], + &ldc[0], group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk_batch, - &upper_lower[0], &trans[0], &n[0], &k[0], &alpha[0], - (const fp **)&a_array[0], &lda[0], &beta[0], &c_array[0], - &ldc[0], group_count, &group_size[0], dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk_batch, + &upper_lower[0], &trans[0], &n[0], &k[0], &alpha[0], + (const fp **)&a_array[0], &lda[0], &beta[0], &c_array[0], + &ldc[0], group_count, &group_size[0], dependencies); break; default: break; } @@ -308,6 +308,8 @@ TEST_P(SyrkBatchUsmTests, RealSinglePrecision) { } TEST_P(SyrkBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -317,13 +319,15 @@ TEST_P(SyrkBatchUsmTests, ComplexSinglePrecision) { } TEST_P(SyrkBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(SyrkBatchUsmTestSuite, SyrkBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/trsm_batch_stride.cpp b/tests/unit_tests/blas/batch/trsm_batch_stride.cpp index 936911c10..cde6aa367 100644 --- a/tests/unit_tests/blas/batch/trsm_batch_stride.cpp +++ b/tests/unit_tests/blas/batch/trsm_batch_stride.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout) { stride_a = (left_right == oneapi::mkl::side::left) ? lda * m : lda * n; switch (layout) { - case oneapi::mkl::layout::column_major: stride_b = ldb * n; break; + case oneapi::mkl::layout::col_major: stride_b = ldb * n; break; case oneapi::mkl::layout::row_major: stride_b = ldb * m; break; default: break; } @@ -146,7 +146,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::trsm_batch( main_queue, left_right, upper_lower, trans, unit_nonunit, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, stride_b, batch_size); @@ -160,15 +160,17 @@ int test(device *dev, oneapi::mkl::layout layout) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm_batch, - left_right, upper_lower, trans, unit_nonunit, m, n, alpha, - A_buffer, lda, stride_a, B_buffer, ldb, stride_b, batch_size); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm_batch, + left_right, upper_lower, trans, unit_nonunit, m, n, alpha, + A_buffer, lda, stride_a, B_buffer, ldb, stride_b, + batch_size); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm_batch, left_right, - upper_lower, trans, unit_nonunit, m, n, alpha, A_buffer, lda, - stride_a, B_buffer, ldb, stride_b, batch_size); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm_batch, + left_right, upper_lower, trans, unit_nonunit, m, n, alpha, + A_buffer, lda, stride_a, B_buffer, ldb, stride_b, + batch_size); break; default: break; } @@ -190,10 +192,10 @@ int test(device *dev, oneapi::mkl::layout layout) { } // Compare the results of reference implementation and DPC++ implementation. - auto B_accessor = B_buffer.template get_host_access(read_only); + auto B_accessor = B_buffer.get_host_access(read_only); bool good = - check_equal_trsm_matrix(B_accessor, B_ref, oneapi::mkl::layout::column_major, total_size_b, - 1, total_size_b, 10 * std::max(m, n), std::cout); + check_equal_trsm_matrix(B_accessor, B_ref, oneapi::mkl::layout::col_major, total_size_b, 1, + total_size_b, 10 * std::max(m, n), std::cout); return (int)good; } @@ -206,6 +208,8 @@ TEST_P(TrsmBatchStrideTests, RealSinglePrecision) { } TEST_P(TrsmBatchStrideTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -214,12 +218,14 @@ TEST_P(TrsmBatchStrideTests, ComplexSinglePrecision) { } TEST_P(TrsmBatchStrideTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(TrsmBatchStrideTestSuite, TrsmBatchStrideTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/trsm_batch_stride_usm.cpp b/tests/unit_tests/blas/batch/trsm_batch_stride_usm.cpp index e24c02066..d99836f87 100644 --- a/tests/unit_tests/blas/batch/trsm_batch_stride_usm.cpp +++ b/tests/unit_tests/blas/batch/trsm_batch_stride_usm.cpp @@ -104,7 +104,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t total_size_b; stride_a = (left_right == oneapi::mkl::side::left) ? lda * m : lda * n; - stride_b = (layout == oneapi::mkl::layout::column_major) ? ldb * n : ldb * m; + stride_b = (layout == oneapi::mkl::layout::col_major) ? ldb * n : ldb * m; total_size_b = batch_size * stride_b; @@ -123,8 +123,8 @@ int test(device *dev, oneapi::mkl::layout layout) { rand_matrix(&B[stride_b * i], layout, oneapi::mkl::transpose::nontrans, m, n, ldb); } - copy_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, - total_size_b, 1, total_size_b, B_ref); + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, total_size_b, + 1, total_size_b, B_ref); // Call reference TRSM_BATCH_STRIDE. using fp_ref = typename ref_type_info::type; @@ -148,7 +148,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::trsm_batch( main_queue, left_right, upper_lower, trans, unit_nonunit, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, batch_size, dependencies); @@ -163,15 +163,17 @@ int test(device *dev, oneapi::mkl::layout layout) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm_batch, - left_right, upper_lower, trans, unit_nonunit, m, n, alpha, &A[0], - lda, stride_a, &B[0], ldb, stride_b, batch_size, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm_batch, + left_right, upper_lower, trans, unit_nonunit, m, n, alpha, + &A[0], lda, stride_a, &B[0], ldb, stride_b, batch_size, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm_batch, left_right, - upper_lower, trans, unit_nonunit, m, n, alpha, &A[0], lda, - stride_a, &B[0], ldb, stride_b, batch_size, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm_batch, + left_right, upper_lower, trans, unit_nonunit, m, n, alpha, + &A[0], lda, stride_a, &B[0], ldb, stride_b, batch_size, + dependencies); break; default: break; } @@ -194,8 +196,8 @@ int test(device *dev, oneapi::mkl::layout layout) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_trsm_matrix(B, B_ref, oneapi::mkl::layout::column_major, total_size_b, - 1, total_size_b, 10 * std::max(m, n), std::cout); + bool good = check_equal_trsm_matrix(B, B_ref, oneapi::mkl::layout::col_major, total_size_b, 1, + total_size_b, 10 * std::max(m, n), std::cout); return (int)good; } @@ -208,6 +210,8 @@ TEST_P(TrsmBatchStrideUsmTests, RealSinglePrecision) { } TEST_P(TrsmBatchStrideUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -216,12 +220,14 @@ TEST_P(TrsmBatchStrideUsmTests, ComplexSinglePrecision) { } TEST_P(TrsmBatchStrideUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(TrsmBatchStrideUsmTestSuite, TrsmBatchStrideUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/batch/trsm_batch_usm.cpp b/tests/unit_tests/blas/batch/trsm_batch_usm.cpp index 4b9dd56db..747f59433 100644 --- a/tests/unit_tests/blas/batch/trsm_batch_usm.cpp +++ b/tests/unit_tests/blas/batch/trsm_batch_usm.cpp @@ -139,7 +139,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { for (i = 0; i < group_count; i++) { size_a = lda[i] * (left_right[i] == oneapi::mkl::side::left ? m[i] : n[i]); Arank = left_right[i] == oneapi::mkl::side::left ? m[i] : n[i]; - size_b = ldb[i] * ((layout == oneapi::mkl::layout::column_major) ? n[i] : m[i]); + size_b = ldb[i] * ((layout == oneapi::mkl::layout::col_major) ? n[i] : m[i]); for (j = 0; j < group_size[i]; j++) { a_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_a, *dev, cxt); b_array[idx] = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp) * size_b, *dev, cxt); @@ -218,7 +218,7 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::trsm_batch( main_queue, &left_right[0], &upper_lower[0], &trans[0], &unit_nonunit[0], &m[0], &n[0], &alpha[0], (const fp **)&a_array[0], &lda[0], &b_array[0], &ldb[0], @@ -235,17 +235,19 @@ int test(device *dev, oneapi::mkl::layout layout, int64_t group_count) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm_batch, - &left_right[0], &upper_lower[0], &trans[0], &unit_nonunit[0], - &m[0], &n[0], &alpha[0], (const fp **)&a_array[0], &lda[0], - &b_array[0], &ldb[0], group_count, &group_size[0], dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm_batch, + &left_right[0], &upper_lower[0], &trans[0], + &unit_nonunit[0], &m[0], &n[0], &alpha[0], + (const fp **)&a_array[0], &lda[0], &b_array[0], &ldb[0], + group_count, &group_size[0], dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm_batch, - &left_right[0], &upper_lower[0], &trans[0], &unit_nonunit[0], - &m[0], &n[0], &alpha[0], (const fp **)&a_array[0], &lda[0], - &b_array[0], &ldb[0], group_count, &group_size[0], dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm_batch, + &left_right[0], &upper_lower[0], &trans[0], + &unit_nonunit[0], &m[0], &n[0], &alpha[0], + (const fp **)&a_array[0], &lda[0], &b_array[0], &ldb[0], + group_count, &group_size[0], dependencies); break; default: break; } @@ -324,6 +326,8 @@ TEST_P(TrsmBatchUsmTests, RealSinglePrecision) { } TEST_P(TrsmBatchUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } @@ -333,13 +337,15 @@ TEST_P(TrsmBatchUsmTests, ComplexSinglePrecision) { } TEST_P(TrsmBatchUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 5)); } INSTANTIATE_TEST_SUITE_P(TrsmBatchUsmTestSuite, TrsmBatchUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/CMakeLists.txt b/tests/unit_tests/blas/extensions/CMakeLists.txt index 43ea3b85d..af58e5076 100644 --- a/tests/unit_tests/blas/extensions/CMakeLists.txt +++ b/tests/unit_tests/blas/extensions/CMakeLists.txt @@ -18,7 +18,7 @@ #=============================================================================== # Build object from all test sources -set(EXTENSIONS_SOURCES "gemm_bias.cpp" "gemmt.cpp" "gemm_bias_usm.cpp" "gemmt_usm.cpp" "omatcopy.cpp" "omatcopy_usm.cpp" "imatcopy.cpp" "imatcopy_usm.cpp" "omatadd.cpp" "omatadd_usm.cpp") +set(EXTENSIONS_SOURCES "gemm_bias.cpp" "gemmt.cpp" "gemm_bias_usm.cpp" "gemmt_usm.cpp" "omatcopy.cpp" "omatcopy_usm.cpp" "imatcopy.cpp" "imatcopy_usm.cpp" "omatadd.cpp" "omatadd_usm.cpp" "omatcopy2.cpp" "omatcopy2_usm.cpp") if(BUILD_SHARED_LIBS) add_library(blas_extensions_rt OBJECT ${EXTENSIONS_SOURCES}) diff --git a/tests/unit_tests/blas/extensions/gemm_bias.cpp b/tests/unit_tests/blas/extensions/gemm_bias.cpp index f5d09eb2f..c6e99e829 100644 --- a/tests/unit_tests/blas/extensions/gemm_bias.cpp +++ b/tests/unit_tests/blas/extensions/gemm_bias.cpp @@ -63,14 +63,11 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, rand_matrix(B, layout, transb, k, n, ldb); rand_matrix(C, layout, oneapi::mkl::transpose::nontrans, m, n, ldc); if (offsetc == oneapi::mkl::offset::fix) - rand_matrix(co, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, 1, 1, - 1); + rand_matrix(co, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, 1, 1, 1); if (offsetc == oneapi::mkl::offset::column) - rand_matrix(co, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, m, 1, - m); + rand_matrix(co, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, m, 1, m); if (offsetc == oneapi::mkl::offset::row) - rand_matrix(co, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, n, 1, - n); + rand_matrix(co, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, n, 1, n); C_ref = C; @@ -115,7 +112,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gemm_bias(main_queue, transa, transb, offsetc, m, n, k, alpha, A_buffer, lda, ao, B_buffer, ldb, bo, beta, C_buffer, ldc, CO_buffer); @@ -129,15 +126,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_bias, transa, - transb, offsetc, m, n, k, alpha, A_buffer, lda, ao, B_buffer, - ldb, bo, beta, C_buffer, ldc, CO_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_bias, + transa, transb, offsetc, m, n, k, alpha, A_buffer, lda, ao, + B_buffer, ldb, bo, beta, C_buffer, ldc, CO_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_bias, transa, - transb, offsetc, m, n, k, alpha, A_buffer, lda, ao, B_buffer, - ldb, bo, beta, C_buffer, ldc, CO_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_bias, transa, + transb, offsetc, m, n, k, alpha, A_buffer, lda, ao, + B_buffer, ldb, bo, beta, C_buffer, ldc, CO_buffer); break; default: break; } @@ -158,7 +155,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, m, n, ldc, 10 * k, std::cout); return (int)good; @@ -381,7 +378,7 @@ TEST_P(GemmBiasTests, Uint8Uint8Int32Precision) { INSTANTIATE_TEST_SUITE_P(GemmBiasTestSuite, GemmBiasTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/gemm_bias_usm.cpp b/tests/unit_tests/blas/extensions/gemm_bias_usm.cpp index 8354ac530..908eed909 100644 --- a/tests/unit_tests/blas/extensions/gemm_bias_usm.cpp +++ b/tests/unit_tests/blas/extensions/gemm_bias_usm.cpp @@ -85,14 +85,11 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, rand_matrix(B, layout, transb, k, n, ldb); rand_matrix(C, layout, oneapi::mkl::transpose::nontrans, m, n, ldc); if (offsetc == oneapi::mkl::offset::fix) - rand_matrix(co, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, 1, 1, - 1); + rand_matrix(co, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, 1, 1, 1); if (offsetc == oneapi::mkl::offset::column) - rand_matrix(co, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, m, 1, - m); + rand_matrix(co, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, m, 1, m); if (offsetc == oneapi::mkl::offset::row) - rand_matrix(co, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, n, 1, - n); + rand_matrix(co, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, n, 1, n); C_ref.resize(C.size()); for (int i = 0; i < C.size(); i++) @@ -118,7 +115,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemm_bias( main_queue, transa, transb, offsetc, m, n, k, alpha, A.data(), lda, ao, B.data(), ldb, bo, beta, C.data(), ldc, co.data(), dependencies); @@ -133,15 +130,17 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_bias, transa, - transb, offsetc, m, n, k, alpha, A.data(), lda, ao, B.data(), - ldb, bo, beta, C.data(), ldc, co.data(), dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm_bias, + transa, transb, offsetc, m, n, k, alpha, A.data(), lda, ao, + B.data(), ldb, bo, beta, C.data(), ldc, co.data(), + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_bias, transa, - transb, offsetc, m, n, k, alpha, A.data(), lda, ao, B.data(), - ldb, bo, beta, C.data(), ldc, co.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm_bias, transa, + transb, offsetc, m, n, k, alpha, A.data(), lda, ao, + B.data(), ldb, bo, beta, C.data(), ldc, co.data(), + dependencies); break; default: break; } @@ -385,7 +384,7 @@ TEST_P(GemmBiasUsmTests, Uint8Uint8Int32Precision) { INSTANTIATE_TEST_SUITE_P(GemmBiasUsmTestSuite, GemmBiasUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/gemmt.cpp b/tests/unit_tests/blas/extensions/gemmt.cpp index e9c266179..228a85d33 100644 --- a/tests/unit_tests/blas/extensions/gemmt.cpp +++ b/tests/unit_tests/blas/extensions/gemmt.cpp @@ -94,7 +94,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gemmt(main_queue, upper_lower, transa, transb, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, ldc); @@ -108,15 +108,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemmt, upper_lower, - transa, transb, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemmt, + upper_lower, transa, transb, n, k, alpha, A_buffer, lda, + B_buffer, ldb, beta, C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemmt, upper_lower, - transa, transb, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemmt, + upper_lower, transa, transb, n, k, alpha, A_buffer, lda, + B_buffer, ldb, beta, C_buffer, ldc); break; default: break; } @@ -136,7 +136,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, upper_lower, n, n, ldc, 10 * k, std::cout); @@ -184,6 +184,8 @@ TEST_P(GemmtTests, RealSinglePrecision) { } TEST_P(GemmtTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -298,6 +300,8 @@ TEST_P(GemmtTests, ComplexSinglePrecision) { } TEST_P(GemmtTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0); std::complex beta(3.0); EXPECT_TRUEORSKIP(test>( @@ -376,7 +380,7 @@ TEST_P(GemmtTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemmtTestSuite, GemmtTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/gemmt_usm.cpp b/tests/unit_tests/blas/extensions/gemmt_usm.cpp index fde2b724f..dac300ae2 100644 --- a/tests/unit_tests/blas/extensions/gemmt_usm.cpp +++ b/tests/unit_tests/blas/extensions/gemmt_usm.cpp @@ -94,7 +94,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemmt( main_queue, upper_lower, transa, transb, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, dependencies); @@ -109,15 +109,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemmt, upper_lower, - transa, transb, n, k, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemmt, + upper_lower, transa, transb, n, k, alpha, A.data(), lda, + B.data(), ldb, beta, C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemmt, upper_lower, - transa, transb, n, k, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemmt, + upper_lower, transa, transb, n, k, alpha, A.data(), lda, + B.data(), ldb, beta, C.data(), ldc, dependencies); break; default: break; } @@ -184,6 +184,8 @@ TEST_P(GemmtUsmTests, RealSinglePrecision) { } TEST_P(GemmtUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -298,6 +300,8 @@ TEST_P(GemmtUsmTests, ComplexSinglePrecision) { } TEST_P(GemmtUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0); std::complex beta(3.0); EXPECT_TRUEORSKIP(test>( @@ -376,7 +380,7 @@ TEST_P(GemmtUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemmtUsmTestSuite, GemmtUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/imatcopy.cpp b/tests/unit_tests/blas/extensions/imatcopy.cpp index 5b3e10216..e21702775 100644 --- a/tests/unit_tests/blas/extensions/imatcopy.cpp +++ b/tests/unit_tests/blas/extensions/imatcopy.cpp @@ -65,7 +65,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t size_a, size_b, size; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda * n; size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; break; @@ -79,10 +79,10 @@ int test(device *dev, oneapi::mkl::layout layout) { vector> AB(size), AB_ref(size); - rand_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size, 1, + rand_matrix(AB, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size, 1, size); - copy_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size, 1, - size, AB_ref); + copy_matrix(AB, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size, 1, size, + AB_ref); // Call reference IMATCOPY. int m_ref = (int)m; @@ -114,7 +114,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::imatcopy(main_queue, trans, m, n, alpha, AB_buffer, lda, ldb); break; @@ -126,13 +126,13 @@ int test(device *dev, oneapi::mkl::layout layout) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy, trans, m, - n, alpha, AB_buffer, lda, ldb); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy, + trans, m, n, alpha, AB_buffer, lda, ldb); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy, trans, m, n, - alpha, AB_buffer, lda, ldb); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy, trans, + m, n, alpha, AB_buffer, lda, ldb); break; default: break; } @@ -154,8 +154,8 @@ int test(device *dev, oneapi::mkl::layout layout) { // Compare the results of reference implementation and DPC++ implementation. - auto AB_accessor = AB_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(AB_accessor, AB_ref, oneapi::mkl::layout::column_major, size, 1, + auto AB_accessor = AB_buffer.get_host_access(read_only); + bool good = check_equal_matrix(AB_accessor, AB_ref, oneapi::mkl::layout::col_major, size, 1, size, 10, std::cout); return (int)good; @@ -169,6 +169,8 @@ TEST_P(ImatcopyTests, RealSinglePrecision) { } TEST_P(ImatcopyTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -177,12 +179,14 @@ TEST_P(ImatcopyTests, ComplexSinglePrecision) { } TEST_P(ImatcopyTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(ImatcopyTestSuite, ImatcopyTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/imatcopy_usm.cpp b/tests/unit_tests/blas/extensions/imatcopy_usm.cpp index 2d8ea88fc..dc3d43d2e 100644 --- a/tests/unit_tests/blas/extensions/imatcopy_usm.cpp +++ b/tests/unit_tests/blas/extensions/imatcopy_usm.cpp @@ -85,7 +85,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t size_a, size_b, size; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda * n; size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; break; @@ -103,10 +103,10 @@ int test(device *dev, oneapi::mkl::layout layout) { AB.resize(size); AB_ref.resize(size); - rand_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size, 1, + rand_matrix(AB, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size, 1, size); - copy_matrix(AB, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size, 1, - size, AB_ref); + copy_matrix(AB, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size, 1, size, + AB_ref); // Call reference IMATCOPY. int m_ref = (int)m; @@ -119,7 +119,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::imatcopy(main_queue, trans, m, n, alpha, &AB[0], lda, ldb, dependencies); break; @@ -132,13 +132,13 @@ int test(device *dev, oneapi::mkl::layout layout) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy, trans, m, - n, alpha, &AB[0], lda, ldb, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::imatcopy, + trans, m, n, alpha, &AB[0], lda, ldb, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy, trans, m, n, - alpha, &AB[0], lda, ldb, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::imatcopy, trans, + m, n, alpha, &AB[0], lda, ldb, dependencies); break; default: break; } @@ -160,7 +160,7 @@ int test(device *dev, oneapi::mkl::layout layout) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(AB, AB_ref, oneapi::mkl::layout::column_major, size, 1, size, 10, + bool good = check_equal_matrix(AB, AB_ref, oneapi::mkl::layout::col_major, size, 1, size, 10, std::cout); return (int)good; @@ -174,6 +174,8 @@ TEST_P(ImatcopyUsmTests, RealSinglePrecision) { } TEST_P(ImatcopyUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -182,12 +184,14 @@ TEST_P(ImatcopyUsmTests, ComplexSinglePrecision) { } TEST_P(ImatcopyUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(ImatcopyUsmTestSuite, ImatcopyUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/omatadd.cpp b/tests/unit_tests/blas/extensions/omatadd.cpp index 46d27bb9b..b2af98935 100644 --- a/tests/unit_tests/blas/extensions/omatadd.cpp +++ b/tests/unit_tests/blas/extensions/omatadd.cpp @@ -69,7 +69,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t size_a, size_b, size_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * n : lda * m; size_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; size_c = ldc * n; @@ -84,14 +84,14 @@ int test(device *dev, oneapi::mkl::layout layout) { vector> A(size_a), B(size_b), C(size_c), C_ref(size_c); - rand_matrix(A.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, - size_a, 1, size_a); - rand_matrix(B.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, - size_b, 1, size_b); - rand_matrix(C.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, - size_c, 1, size_c); - copy_matrix(C.data(), oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, - size_c, 1, size_c, C_ref.data()); + rand_matrix(A.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_a, + 1, size_a); + rand_matrix(B.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, + 1, size_b); + rand_matrix(C.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_c, + 1, size_c); + copy_matrix(C.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_c, + 1, size_c, C_ref.data()); // Call reference OMATADD. int m_ref = (int)m; @@ -127,7 +127,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::omatadd(main_queue, transa, transb, m, n, alpha, A_buffer, lda, beta, B_buffer, ldb, C_buffer, ldc); @@ -141,15 +141,15 @@ int test(device *dev, oneapi::mkl::layout layout) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd, transa, - transb, m, n, alpha, A_buffer, lda, beta, B_buffer, ldb, - C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd, + transa, transb, m, n, alpha, A_buffer, lda, beta, B_buffer, + ldb, C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd, transa, - transb, m, n, alpha, A_buffer, lda, beta, B_buffer, ldb, - C_buffer, ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd, transa, + transb, m, n, alpha, A_buffer, lda, beta, B_buffer, ldb, + C_buffer, ldc); break; default: break; } @@ -170,8 +170,8 @@ int test(device *dev, oneapi::mkl::layout layout) { // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::column_major, size_c, 1, + auto C_accessor = C_buffer.get_host_access(read_only); + bool good = check_equal_matrix(C_accessor, C_ref, oneapi::mkl::layout::col_major, size_c, 1, size_c, 10, std::cout); return (int)good; @@ -185,6 +185,8 @@ TEST_P(OmataddTests, RealSinglePrecision) { } TEST_P(OmataddTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -193,12 +195,14 @@ TEST_P(OmataddTests, ComplexSinglePrecision) { } TEST_P(OmataddTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(OmataddTestSuite, OmataddTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/omatadd_usm.cpp b/tests/unit_tests/blas/extensions/omatadd_usm.cpp index 52dbdcc59..783f985b2 100644 --- a/tests/unit_tests/blas/extensions/omatadd_usm.cpp +++ b/tests/unit_tests/blas/extensions/omatadd_usm.cpp @@ -88,7 +88,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t size_a, size_b, size_c; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = (transa == oneapi::mkl::transpose::nontrans) ? lda * n : lda * m; size_b = (transb == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; size_c = ldc * n; @@ -109,13 +109,13 @@ int test(device *dev, oneapi::mkl::layout layout) { C.resize(size_c); C_ref.resize(size_c); - rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_a, 1, + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_a, 1, size_a); - rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_b, 1, + rand_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, size_b); - rand_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_c, 1, + rand_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_c, 1, size_c); - copy_matrix(C, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_c, 1, + copy_matrix(C, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_c, 1, size_c, C_ref); // Call reference OMATADD. @@ -131,7 +131,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::omatadd(main_queue, transa, transb, m, n, alpha, &A[0], lda, beta, &B[0], ldb, &C[0], ldc, dependencies); @@ -146,15 +146,15 @@ int test(device *dev, oneapi::mkl::layout layout) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd, transa, - transb, m, n, alpha, &A[0], lda, beta, &B[0], ldb, &C[0], ldc, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatadd, + transa, transb, m, n, alpha, &A[0], lda, beta, &B[0], ldb, + &C[0], ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd, transa, - transb, m, n, alpha, &A[0], lda, beta, &B[0], ldb, &C[0], ldc, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatadd, transa, + transb, m, n, alpha, &A[0], lda, beta, &B[0], ldb, &C[0], + ldc, dependencies); break; default: break; } @@ -175,8 +175,8 @@ int test(device *dev, oneapi::mkl::layout layout) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::column_major, size_c, 1, size_c, - 10, std::cout); + bool good = check_equal_matrix(C, C_ref, oneapi::mkl::layout::col_major, size_c, 1, size_c, 10, + std::cout); return (int)good; } @@ -189,6 +189,8 @@ TEST_P(OmataddUsmTests, RealSinglePrecision) { } TEST_P(OmataddUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -197,12 +199,14 @@ TEST_P(OmataddUsmTests, ComplexSinglePrecision) { } TEST_P(OmataddUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(OmataddUsmTestSuite, OmataddUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/omatcopy.cpp b/tests/unit_tests/blas/extensions/omatcopy.cpp index b89391893..122ba2c79 100644 --- a/tests/unit_tests/blas/extensions/omatcopy.cpp +++ b/tests/unit_tests/blas/extensions/omatcopy.cpp @@ -76,7 +76,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t size_a, size_b; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda * n; size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; break; @@ -123,7 +123,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::omatcopy(main_queue, trans, m, n, alpha, A_buffer, lda, B_buffer, ldb); break; @@ -135,13 +135,13 @@ int test(device *dev, oneapi::mkl::layout layout) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy, trans, m, - n, alpha, A_buffer, lda, B_buffer, ldb); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy, + trans, m, n, alpha, A_buffer, lda, B_buffer, ldb); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy, trans, m, n, - alpha, A_buffer, lda, B_buffer, ldb); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy, trans, + m, n, alpha, A_buffer, lda, B_buffer, ldb); break; default: break; } @@ -163,8 +163,8 @@ int test(device *dev, oneapi::mkl::layout layout) { // Compare the results of reference implementation and DPC++ implementation. - auto B_accessor = B_buffer.template get_host_access(read_only); - bool good = check_equal_matrix(B_accessor, B_ref, oneapi::mkl::layout::column_major, size_b, 1, + auto B_accessor = B_buffer.get_host_access(read_only); + bool good = check_equal_matrix(B_accessor, B_ref, oneapi::mkl::layout::col_major, size_b, 1, size_b, 10, std::cout); return (int)good; @@ -178,6 +178,8 @@ TEST_P(OmatcopyTests, RealSinglePrecision) { } TEST_P(OmatcopyTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -186,12 +188,14 @@ TEST_P(OmatcopyTests, ComplexSinglePrecision) { } TEST_P(OmatcopyTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(OmatcopyTestSuite, OmatcopyTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/extensions/omatcopy2.cpp b/tests/unit_tests/blas/extensions/omatcopy2.cpp new file mode 100644 index 000000000..d0407c324 --- /dev/null +++ b/tests/unit_tests/blas/extensions/omatcopy2.cpp @@ -0,0 +1,201 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout) { + // Prepare data. + int64_t m, n; + int64_t lda, ldb; + int64_t stride_a, stride_b; + oneapi::mkl::transpose trans; + fp alpha; + + stride_a = 1 + std::rand() % 50; + stride_b = 1 + std::rand() % 50; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = stride_a * (std::max(m, n) - 1) + 1; + ldb = stride_b * (std::max(m, n) - 1) + 1; + alpha = rand_scalar(); + trans = rand_trans(); + + int64_t size_a, size_b; + + switch (layout) { + case oneapi::mkl::layout::col_major: + size_a = lda * n; + size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + break; + case oneapi::mkl::layout::row_major: + size_a = lda * m; + size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + break; + default: break; + } + + vector> A(size_a), B(size_b), B_ref(size_b); + + rand_matrix(A.data(), layout, oneapi::mkl::transpose::nontrans, m, n, lda); + rand_matrix(B.data(), layout, trans, m, n, ldb); + copy_matrix(B.data(), oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, + 1, size_b, B_ref.data()); + + // Call reference OMATCOPY2. + int64_t m_ref = m; + int64_t n_ref = n; + int64_t lda_ref = lda; + int64_t ldb_ref = ldb; + int64_t stride_a_ref = stride_a; + int64_t stride_b_ref = stride_b; + omatcopy2_ref(layout, trans, m_ref, n_ref, alpha, A.data(), lda_ref, stride_a_ref, B_ref.data(), + ldb_ref, stride_b_ref); + + // Call DPC++ OMATCOPY2 + + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATCOPY2:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + + buffer A_buffer(A.data(), range<1>(A.size())); + buffer B_buffer(B.data(), range<1>(B.size())); + + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::col_major: + oneapi::mkl::blas::column_major::omatcopy2(main_queue, trans, m, n, alpha, A_buffer, + lda, stride_a, B_buffer, ldb, stride_b); + break; + case oneapi::mkl::layout::row_major: + oneapi::mkl::blas::row_major::omatcopy2(main_queue, trans, m, n, alpha, A_buffer, + lda, stride_a, B_buffer, ldb, stride_b); + break; + default: break; + } +#else + switch (layout) { + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy2, + trans, m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, + stride_b); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy2, trans, + m, n, alpha, A_buffer, lda, stride_a, B_buffer, ldb, + stride_b); + break; + default: break; + } +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during OMATCOPY2:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of OMATCOPY2:\n" << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + + auto B_accessor = B_buffer.get_host_access(read_only); + bool good = check_equal_matrix(B_accessor, B_ref, oneapi::mkl::layout::col_major, size_b, 1, + size_b, 10, std::cout); + + return (int)good; +} + +class Omatcopy2Tests + : public ::testing::TestWithParam> {}; + +TEST_P(Omatcopy2Tests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +TEST_P(Omatcopy2Tests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +TEST_P(Omatcopy2Tests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +TEST_P(Omatcopy2Tests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Omatcopy2TestSuite, Omatcopy2Tests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::col_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp b/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp new file mode 100644 index 000000000..d2103d243 --- /dev/null +++ b/tests/unit_tests/blas/extensions/omatcopy2_usm.cpp @@ -0,0 +1,210 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif +#include "allocator_helper.hpp" +#include "cblas.h" +#include "oneapi/mkl/detail/config.hpp" +#include "oneapi/mkl.hpp" +#include "onemkl_blas_helper.hpp" +#include "reference_blas_templates.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +using namespace sycl; +using std::vector; + +extern std::vector devices; + +namespace { + +template +int test(device *dev, oneapi::mkl::layout layout) { + // Catch asynchronous exceptions. + auto exception_handler = [](exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (exception const &e) { + std::cout << "Caught asynchronous SYCL exception during OMATCOPY2:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + queue main_queue(*dev, exception_handler); + context cxt = main_queue.get_context(); + event done; + std::vector dependencies; + + // Prepare data. + int64_t m, n; + int64_t lda, ldb; + int64_t stride_a, stride_b; + oneapi::mkl::transpose trans; + fp alpha; + + stride_a = 1 + std::rand() % 50; + stride_b = 1 + std::rand() % 50; + m = 1 + std::rand() % 50; + n = 1 + std::rand() % 50; + lda = stride_a * (std::max(m, n) - 1) + 1; + ldb = stride_b * (std::max(m, n) - 1) + 1; + alpha = rand_scalar(); + trans = rand_trans(); + + int64_t size_a, size_b; + + switch (layout) { + case oneapi::mkl::layout::col_major: + size_a = lda * n; + size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; + break; + case oneapi::mkl::layout::row_major: + size_a = lda * m; + size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * m : ldb * n; + break; + default: break; + } + + auto ua = usm_allocator(cxt, *dev); + vector A(ua), B(ua), B_ref(ua); + + A.resize(size_a); + B.resize(size_b); + B_ref.resize(size_b); + + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_a, 1, + size_a); + rand_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, + size_b); + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, + size_b, B_ref); + + // Call reference OMATCOPY2. + int64_t m_ref = m; + int64_t n_ref = n; + int64_t lda_ref = lda; + int64_t ldb_ref = ldb; + int64_t stride_a_ref = stride_a; + int64_t stride_b_ref = stride_b; + omatcopy2_ref(layout, trans, m_ref, n_ref, alpha, A.data(), lda_ref, stride_a_ref, B_ref.data(), + ldb_ref, stride_b_ref); + + // Call DPC++ OMATCOPY2 + try { +#ifdef CALL_RT_API + switch (layout) { + case oneapi::mkl::layout::col_major: + done = oneapi::mkl::blas::column_major::omatcopy2(main_queue, trans, m, n, alpha, + &A[0], lda, stride_a, &B[0], ldb, + stride_b, dependencies); + break; + case oneapi::mkl::layout::row_major: + done = oneapi::mkl::blas::row_major::omatcopy2(main_queue, trans, m, n, alpha, + &A[0], lda, stride_a, &B[0], ldb, + stride_b, dependencies); + break; + default: break; + } + done.wait(); +#else + switch (layout) { + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy2, + trans, m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, + stride_b, dependencies); + break; + case oneapi::mkl::layout::row_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy2, trans, + m, n, alpha, &A[0], lda, stride_a, &B[0], ldb, stride_b, + dependencies); + break; + default: break; + } + main_queue.wait(); +#endif + } + catch (exception const &e) { + std::cout << "Caught synchronous SYCL exception during OMATCOPY2:\n" + << e.what() << std::endl; + print_error_code(e); + } + + catch (const oneapi::mkl::unimplemented &e) { + return test_skipped; + } + + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of OMATCOPY2:\n" << error.what() << std::endl; + } + + // Compare the results of reference implementation and DPC++ implementation. + bool good = check_equal_matrix(B, B_ref, oneapi::mkl::layout::col_major, size_b, 1, size_b, 10, + std::cout); + + return (int)good; +} + +class Omatcopy2UsmTests + : public ::testing::TestWithParam> {}; + +TEST_P(Omatcopy2UsmTests, RealSinglePrecision) { + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +TEST_P(Omatcopy2UsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +TEST_P(Omatcopy2UsmTests, ComplexSinglePrecision) { + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +TEST_P(Omatcopy2UsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Omatcopy2UsmTestSuite, Omatcopy2UsmTests, + ::testing::Combine(testing::ValuesIn(devices), + testing::Values(oneapi::mkl::layout::col_major, + oneapi::mkl::layout::row_major)), + ::LayoutDeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/blas/extensions/omatcopy_usm.cpp b/tests/unit_tests/blas/extensions/omatcopy_usm.cpp index 270c14696..ac9ba2d5c 100644 --- a/tests/unit_tests/blas/extensions/omatcopy_usm.cpp +++ b/tests/unit_tests/blas/extensions/omatcopy_usm.cpp @@ -85,7 +85,7 @@ int test(device *dev, oneapi::mkl::layout layout) { int64_t size_a, size_b; switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: size_a = lda * n; size_b = (trans == oneapi::mkl::transpose::nontrans) ? ldb * n : ldb * m; break; @@ -103,11 +103,11 @@ int test(device *dev, oneapi::mkl::layout layout) { B.resize(size_b); B_ref.resize(size_b); - rand_matrix(A, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_a, 1, + rand_matrix(A, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_a, 1, size_a); - rand_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_b, 1, + rand_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, size_b); - copy_matrix(B, oneapi::mkl::layout::column_major, oneapi::mkl::transpose::nontrans, size_b, 1, + copy_matrix(B, oneapi::mkl::layout::col_major, oneapi::mkl::transpose::nontrans, size_b, 1, size_b, B_ref); // Call reference OMATCOPY. @@ -121,7 +121,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::omatcopy( main_queue, trans, m, n, alpha, &A[0], lda, &B[0], ldb, dependencies); break; @@ -134,13 +134,13 @@ int test(device *dev, oneapi::mkl::layout layout) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy, trans, m, - n, alpha, &A[0], lda, &B[0], ldb, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::omatcopy, + trans, m, n, alpha, &A[0], lda, &B[0], ldb, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy, trans, m, n, - alpha, &A[0], lda, &B[0], ldb, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::omatcopy, trans, + m, n, alpha, &A[0], lda, &B[0], ldb, dependencies); break; default: break; } @@ -162,8 +162,8 @@ int test(device *dev, oneapi::mkl::layout layout) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal_matrix(B, B_ref, oneapi::mkl::layout::column_major, size_b, 1, size_b, - 10, std::cout); + bool good = check_equal_matrix(B, B_ref, oneapi::mkl::layout::col_major, size_b, 1, size_b, 10, + std::cout); return (int)good; } @@ -176,6 +176,8 @@ TEST_P(OmatcopyUsmTests, RealSinglePrecision) { } TEST_P(OmatcopyUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } @@ -184,12 +186,14 @@ TEST_P(OmatcopyUsmTests, ComplexSinglePrecision) { } TEST_P(OmatcopyUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(OmatcopyUsmTestSuite, OmatcopyUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/include/onemkl_blas_helper.hpp b/tests/unit_tests/blas/include/onemkl_blas_helper.hpp index 9c8b3055f..5489aaa61 100644 --- a/tests/unit_tests/blas/include/onemkl_blas_helper.hpp +++ b/tests/unit_tests/blas/include/onemkl_blas_helper.hpp @@ -62,8 +62,8 @@ inline CBLAS_OFFSET convert_to_cblas_offset(oneapi::mkl::offset offsetc) { } inline CBLAS_LAYOUT convert_to_cblas_layout(oneapi::mkl::layout is_column) { - return is_column == oneapi::mkl::layout::column_major ? CBLAS_LAYOUT::CblasColMajor - : CBLAS_LAYOUT::CblasRowMajor; + return is_column == oneapi::mkl::layout::col_major ? CBLAS_LAYOUT::CblasColMajor + : CBLAS_LAYOUT::CblasRowMajor; } static const CBLAS_TRANSPOSE fcblastrans[] = { CblasNoTrans, CblasTrans, CblasConjTrans }; diff --git a/tests/unit_tests/blas/include/reference_blas_templates.hpp b/tests/unit_tests/blas/include/reference_blas_templates.hpp index 73d919412..6d184ba75 100644 --- a/tests/unit_tests/blas/include/reference_blas_templates.hpp +++ b/tests/unit_tests/blas/include/reference_blas_templates.hpp @@ -22,7 +22,9 @@ #include #include +#include #include "cblas.h" +#include "oneapi/mkl/types.hpp" #include "test_helper.hpp" #include "reference_blas_wrappers.hpp" @@ -1979,7 +1981,7 @@ template void omatcopy_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int64_t m, int64_t n, fp alpha, fp *A, int64_t lda, fp *B, int64_t ldb) { int64_t logical_m, logical_n; - if (layout == oneapi::mkl::layout::column_major) { + if (layout == oneapi::mkl::layout::col_major) { logical_m = m; logical_n = n; } @@ -2011,11 +2013,57 @@ void omatcopy_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int6 } } +template +void omatcopy2_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, const int64_t &m, + const int64_t &n, const fp &alpha, const fp *in_matrix, const int64_t &ld_in, + const int64_t &inc_in, fp *out_matrix, const int64_t &ld_out, + const int64_t inc_out) { + int64_t logical_m, logical_n; + if (layout == oneapi::mkl::layout::col_major) { + logical_m = m; + logical_n = n; + } + else { + logical_m = n; + logical_n = m; + } + if (trans == oneapi::mkl::transpose::trans) { + for (int64_t i = 0; i < logical_m; ++i) { + for (int64_t j = 0; j < logical_n; ++j) { + { + out_matrix[j * inc_out + i * ld_out] = + alpha * in_matrix[i * inc_in + j * ld_in]; + } + } + } + } + else if (trans == oneapi::mkl::transpose::nontrans) { + for (int i = 0; i < logical_n; ++i) { + for (int j = 0; j < logical_m; ++j) { + { + out_matrix[j * inc_out + i * ld_out] = + alpha * in_matrix[j * inc_in + i * ld_in]; + } + } + } + } + else { + for (int64_t i = 0; i < logical_m; ++i) { + for (int64_t j = 0, c = 0; j < logical_n; ++j, ++c) { + out_matrix[j * inc_out + i * ld_out] = + alpha * sametype_conj(in_matrix[i * inc_in + j * ld_in]); + } + } + } + + return; +} + template void imatcopy_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int64_t m, int64_t n, fp alpha, fp *A, int64_t lda, int64_t ldb) { int64_t logical_m, logical_n; - if (layout == oneapi::mkl::layout::column_major) { + if (layout == oneapi::mkl::layout::col_major) { logical_m = m; logical_n = n; } @@ -2070,7 +2118,7 @@ void omatadd_ref(oneapi::mkl::layout layout, oneapi::mkl::transpose transa, oneapi::mkl::transpose transb, int64_t m, int64_t n, fp alpha, fp *A, int64_t lda, fp beta, fp *B, int64_t ldb, fp *C, int64_t ldc) { int64_t logical_m, logical_n; - if (layout == oneapi::mkl::layout::column_major) { + if (layout == oneapi::mkl::layout::col_major) { logical_m = m; logical_n = n; } diff --git a/tests/unit_tests/blas/include/test_common.hpp b/tests/unit_tests/blas/include/test_common.hpp index ed4d61804..5d607991e 100644 --- a/tests/unit_tests/blas/include/test_common.hpp +++ b/tests/unit_tests/blas/include/test_common.hpp @@ -80,8 +80,8 @@ constexpr T matrix_size(oneapi::mkl::transpose trans, T m, T n, T ldm) { } template constexpr T matrix_size(oneapi::mkl::layout layout, oneapi::mkl::transpose trans, T m, T n, T ldm) { - return (layout == oneapi::mkl::layout::column_major) ? outer_dimension(trans, m, n) * ldm - : inner_dimension(trans, m, n) * ldm; + return (layout == oneapi::mkl::layout::col_major) ? outer_dimension(trans, m, n) * ldm + : inner_dimension(trans, m, n) * ldm; } // SYCL buffer creation helper. @@ -235,7 +235,7 @@ void copy_matrix(vec_src &src, oneapi::mkl::layout layout, oneapi::mkl::transpos using T_data = typename vec_dest::value_type; dest.resize(matrix_size(layout, trans, m, n, ld)); if (((trans == oneapi::mkl::transpose::nontrans) && - (layout == oneapi::mkl::layout::column_major)) || + (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) @@ -249,21 +249,21 @@ void copy_matrix(vec_src &src, oneapi::mkl::layout layout, oneapi::mkl::transpos } } -template -void copy_matrix(fp *src, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, int n, - int ld, fp *dest) { +template +void copy_matrix(fp_src *src, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, + int n, int ld, fp_dst *dest) { if (((trans == oneapi::mkl::transpose::nontrans) && - (layout == oneapi::mkl::layout::column_major)) || + (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) for (int i = 0; i < m; i++) - dest[i + j * ld] = (fp)src[i + j * ld]; + dest[i + j * ld] = (fp_dst)src[i + j * ld]; } else { for (int i = 0; i < m; i++) for (int j = 0; j < n; j++) - dest[j + i * ld] = (fp)src[j + i * ld]; + dest[j + i * ld] = (fp_dst)src[j + i * ld]; } } @@ -293,7 +293,7 @@ void rand_matrix(vec &M, oneapi::mkl::layout layout, oneapi::mkl::transpose tran M.resize(matrix_size(layout, trans, m, n, ld)); if (((trans == oneapi::mkl::transpose::nontrans) && - (layout == oneapi::mkl::layout::column_major)) || + (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) @@ -311,7 +311,7 @@ template void rand_matrix(fp *M, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, int n, int ld) { if (((trans == oneapi::mkl::transpose::nontrans) && - (layout == oneapi::mkl::layout::column_major)) || + (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) @@ -333,7 +333,7 @@ void rand_trsm_matrix(vec &M, oneapi::mkl::layout layout, oneapi::mkl::transpose M.resize(matrix_size(layout, trans, m, n, ld)); if (((trans == oneapi::mkl::transpose::nontrans) && - (layout == oneapi::mkl::layout::column_major)) || + (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) @@ -359,7 +359,7 @@ template void rand_trsm_matrix(fp *M, oneapi::mkl::layout layout, oneapi::mkl::transpose trans, int m, int n, int ld) { if (((trans == oneapi::mkl::transpose::nontrans) && - (layout == oneapi::mkl::layout::column_major)) || + (layout == oneapi::mkl::layout::col_major)) || ((trans != oneapi::mkl::transpose::nontrans) && (layout == oneapi::mkl::layout::row_major))) { for (int j = 0; j < n; j++) @@ -392,7 +392,7 @@ void rand_tpsv_matrix(vec &M, oneapi::mkl::layout layout, oneapi::mkl::uplo uppe M.resize((m * (m + 1)) / 2); for (j = 0; j < m; j++) { - if (layout == oneapi::mkl::layout::column_major) { + if (layout == oneapi::mkl::layout::col_major) { start = (upper_lower == oneapi::mkl::uplo::U) ? 0 : j; end = (upper_lower == oneapi::mkl::uplo::U) ? j : m - 1; } @@ -417,7 +417,7 @@ void rand_tbsv_matrix(vec &M, oneapi::mkl::layout layout, oneapi::mkl::uplo uppe rand_trsm_matrix(tmp, layout, trans, m, m, ld); M.resize(matrix_size(layout, trans, m, m, ld)); - if (((layout == oneapi::mkl::layout::column_major) && (upper_lower == oneapi::mkl::uplo::U)) || + if (((layout == oneapi::mkl::layout::col_major) && (upper_lower == oneapi::mkl::uplo::U)) || ((layout == oneapi::mkl::layout::row_major) && (upper_lower == oneapi::mkl::uplo::L))) { for (j = 0; j < m; j++) { n = k - j; @@ -460,6 +460,13 @@ typename std::enable_if::value, bool>::type check_equal(fp return (x == x_ref); } +template +bool check_equal_ptr(sycl::queue queue, fp *x, fp x_ref, int error_mag) { + fp x_host; + queue.memcpy(&x_host, x, sizeof(fp)).wait(); + return check_equal(x_host, x_ref, error_mag); +} + template bool check_equal_trsm(fp x, fp x_ref, int error_mag) { using fp_real = typename complex_info::real_type; @@ -487,6 +494,13 @@ bool check_equal(fp x, fp x_ref, int error_mag, std::ostream &out) { return good; } +template +bool check_equal_ptr(sycl::queue queue, fp *x, fp x_ref, int error_mag, std::ostream &out) { + fp x_host; + queue.memcpy(&x_host, x, sizeof(fp)).wait(); + return check_equal(x_host, x_ref, error_mag, out); +} + template bool check_equal_vector(const fp *v, const fp *v_ref, int n, int inc, int error_mag, std::ostream &out) { @@ -556,7 +570,7 @@ bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int m, int idx, count = 0; for (int j = 0; j < n; j++) { for (int i = 0; i < m; i++) { - idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld; + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; if (!check_equal(M[idx], M_ref[idx], error_mag)) { out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] << " vs. Reference " << M_ref[idx] << std::endl; @@ -578,7 +592,7 @@ bool check_equal_matrix(const fp *M, const fp *M_ref, oneapi::mkl::layout layout int idx, count = 0; for (int j = 0; j < n; j++) { for (int i = 0; i < m; i++) { - idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld; + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; if (!check_equal(M[idx], M_ref[idx], error_mag)) { out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] << " vs. Reference " << M_ref[idx] << std::endl; @@ -601,7 +615,7 @@ bool check_equal_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, int idx, count = 0; for (int j = 0; j < n; j++) { for (int i = 0; i < m; i++) { - idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld; + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; if (((upper_lower == oneapi::mkl::uplo::upper) && (j >= i)) || ((upper_lower == oneapi::mkl::uplo::lower) && (j <= i))) { if (!check_equal(M[idx], M_ref[idx], error_mag)) { @@ -626,7 +640,7 @@ bool check_equal_trsm_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, i int idx, count = 0; for (int j = 0; j < n; j++) { for (int i = 0; i < m; i++) { - idx = (layout == oneapi::mkl::layout::column_major) ? i + j * ld : j + i * ld; + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; if (!check_equal_trsm(M[idx], M_ref[idx], error_mag)) { out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] << " vs. Reference " << M_ref[idx] << std::endl; @@ -641,4 +655,57 @@ bool check_equal_trsm_matrix(acc1 &M, acc2 &M_ref, oneapi::mkl::layout layout, i return good; } +// Helper for using std::result_of for evalutation operator[] return type +template +struct access_index { + auto operator()(T M) { + return M[0]; + } +}; + +// Helper for checking if a matrix/vector/accessor structure returns an integral type +template +constexpr bool is_matrix_type_integral() { + return std::is_integral_v< + std::remove_reference_t(T)>::type>>; +} + +template +typename std::enable_if::value, bool>::type check_almost_equal_int( + fp x, fp x_ref, int error_mag) { + return (std::abs(x - x_ref) <= error_mag); +} + +template +bool check_almost_equal_matrix_int(Ta &M, Tb &M_ref, oneapi::mkl::layout layout, int m, int n, + int ld, int error_mag, std::ostream &out) { + static_assert(is_matrix_type_integral() && is_matrix_type_integral()); + bool good = true; + int idx, count = 0; + for (int j = 0; j < n; j++) { + for (int i = 0; i < m; i++) { + idx = (layout == oneapi::mkl::layout::col_major) ? i + j * ld : j + i * ld; + if (!check_almost_equal_int(M[idx], M_ref[idx], error_mag)) { + out << "Difference in entry (" << i << ',' << j << "): DPC++ " << M[idx] + << " vs. Reference " << M_ref[idx] << std::endl; + good = false; + count++; + if (count > MAX_NUM_PRINT) + return good; + } + } + } + + return good; +} + +template +bool check_almost_equal_matrix(Ta &M, Tb &M_ref, oneapi::mkl::layout layout, int m, int n, int ld, + int error_mag, std::ostream &out) { + // Only call if returned dtype is integral + if constexpr (is_matrix_type_integral() && is_matrix_type_integral()) + return check_almost_equal_matrix_int(M, M_ref, layout, m, n, ld, error_mag, out); + return check_equal_matrix(M, M_ref, layout, m, n, ld, error_mag, out); +} + #endif /* header guard */ diff --git a/tests/unit_tests/blas/level1/asum.cpp b/tests/unit_tests/blas/level1/asum.cpp index 503f75045..6969789e3 100644 --- a/tests/unit_tests/blas/level1/asum.cpp +++ b/tests/unit_tests/blas/level1/asum.cpp @@ -82,7 +82,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::asum(main_queue, N, x_buffer, incx, result_buffer); break; case oneapi::mkl::layout::row_major: @@ -92,13 +92,13 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::asum, N, x_buffer, - incx, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::asum, N, + x_buffer, incx, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::asum, N, x_buffer, - incx, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::asum, N, x_buffer, + incx, result_buffer); break; default: break; } @@ -119,7 +119,7 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_ref, N, std::cout); return (int)good; @@ -138,6 +138,8 @@ TEST_P(AsumTests, RealSinglePrecision) { } TEST_P(AsumTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP( @@ -156,6 +158,8 @@ TEST_P(AsumTests, ComplexSinglePrecision) { } TEST_P(AsumTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), @@ -166,7 +170,7 @@ TEST_P(AsumTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AsumTestSuite, AsumTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/asum_usm.cpp b/tests/unit_tests/blas/level1/asum_usm.cpp index ac27fe16b..b42799abd 100644 --- a/tests/unit_tests/blas/level1/asum_usm.cpp +++ b/tests/unit_tests/blas/level1/asum_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -81,12 +81,21 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { // Call DPC++ ASUM. - auto result_p = (fp_res*)oneapi::mkl::malloc_shared(64, sizeof(fp_res), *dev, cxt); + fp_res* result_p; + if constexpr (alloc_type == usm::alloc::shared) { + result_p = (fp_res*)oneapi::mkl::malloc_shared(64, sizeof(fp_res), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::device) { + result_p = (fp_res*)oneapi::mkl::malloc_device(64, sizeof(fp_res), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::asum(main_queue, N, x.data(), incx, result_p, dependencies); break; @@ -99,13 +108,13 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::asum, N, x.data(), - incx, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::asum, N, + x.data(), incx, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::asum, N, x.data(), - incx, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::asum, N, x.data(), + incx, result_p, dependencies); break; default: break; } @@ -127,9 +136,9 @@ int test(device* dev, oneapi::mkl::layout layout, int64_t N, int64_t incx) { // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal(*result_p, result_ref, N, std::cout); + bool good = check_equal_ptr(main_queue, result_p, result_ref, N, std::cout); - oneapi::mkl::free_shared(result_p, cxt); + oneapi::mkl::free_usm(result_p, cxt); return (int)good; } @@ -142,15 +151,21 @@ TEST_P(AsumUsmTests, RealSinglePrecision) { (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP( (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((::test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } TEST_P(AsumUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP( (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((::test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( (::test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } @@ -160,22 +175,28 @@ TEST_P(AsumUsmTests, ComplexSinglePrecision) { std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP((::test, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((::test, float, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP((::test, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } TEST_P(AsumUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((::test, double, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } INSTANTIATE_TEST_SUITE_P(AsumUsmTestSuite, AsumUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/axpby.cpp b/tests/unit_tests/blas/level1/axpby.cpp index a7d4c86bf..d43f9beda 100644 --- a/tests/unit_tests/blas/level1/axpby.cpp +++ b/tests/unit_tests/blas/level1/axpby.cpp @@ -85,7 +85,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::axpby(main_queue, N, alpha, x_buffer, incx, beta, y_buffer, incy); break; @@ -97,13 +97,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpby, N, alpha, - x_buffer, incx, beta, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpby, N, + alpha, x_buffer, incx, beta, y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpby, N, alpha, - x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpby, N, alpha, + x_buffer, incx, beta, y_buffer, incy); break; default: break; } @@ -124,7 +124,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, N, incy, N, std::cout); return (int)good; @@ -144,6 +144,8 @@ TEST_P(AxpbyTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2, alpha, beta)); } TEST_P(AxpbyTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP( @@ -164,6 +166,8 @@ TEST_P(AxpbyTests, ComplexSinglePrecision) { 1357, -3, -2, alpha, beta)); } TEST_P(AxpbyTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -176,7 +180,7 @@ TEST_P(AxpbyTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AxpbyTestSuite, AxpbyTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/axpby_usm.cpp b/tests/unit_tests/blas/level1/axpby_usm.cpp index 23e9bd188..ae85ca8f1 100644 --- a/tests/unit_tests/blas/level1/axpby_usm.cpp +++ b/tests/unit_tests/blas/level1/axpby_usm.cpp @@ -87,7 +87,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::axpby(main_queue, N, alpha, x.data(), incx, beta, y.data(), incy, dependencies); break; @@ -100,13 +100,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpby, N, alpha, - x.data(), incx, beta, y.data(), incy, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpby, N, + alpha, x.data(), incx, beta, y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpby, N, alpha, - x.data(), incx, beta, y.data(), incy, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpby, N, alpha, + x.data(), incx, beta, y.data(), incy, dependencies); break; default: break; } @@ -147,6 +147,8 @@ TEST_P(AxpbyUsmTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2, alpha, beta)); } TEST_P(AxpbyUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP( @@ -167,6 +169,8 @@ TEST_P(AxpbyUsmTests, ComplexSinglePrecision) { 1357, -3, -2, alpha, beta)); } TEST_P(AxpbyUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -179,7 +183,7 @@ TEST_P(AxpbyUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AxpbyUsmTestSuite, AxpbyUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/axpy.cpp b/tests/unit_tests/blas/level1/axpy.cpp index 37b217250..c81f2902d 100644 --- a/tests/unit_tests/blas/level1/axpy.cpp +++ b/tests/unit_tests/blas/level1/axpy.cpp @@ -85,7 +85,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::axpy(main_queue, N, alpha, x_buffer, incx, y_buffer, incy); break; @@ -97,13 +97,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy, N, alpha, - x_buffer, incx, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy, N, alpha, + x_buffer, incx, y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy, N, alpha, - x_buffer, incx, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy, N, alpha, + x_buffer, incx, y_buffer, incy); break; default: break; } @@ -124,7 +124,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, N, incy, N, std::cout); return (int)good; @@ -143,6 +143,8 @@ TEST_P(AxpyTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2, alpha)); } TEST_P(AxpyTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, alpha)); @@ -161,6 +163,8 @@ TEST_P(AxpyTests, ComplexSinglePrecision) { 1357, -3, -2, alpha)); } TEST_P(AxpyTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, alpha)); @@ -172,7 +176,7 @@ TEST_P(AxpyTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AxpyTestSuite, AxpyTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/axpy_usm.cpp b/tests/unit_tests/blas/level1/axpy_usm.cpp index 1c87c4109..da68f173c 100644 --- a/tests/unit_tests/blas/level1/axpy_usm.cpp +++ b/tests/unit_tests/blas/level1/axpy_usm.cpp @@ -87,7 +87,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::axpy(main_queue, N, alpha, x.data(), incx, y.data(), incy, dependencies); break; @@ -100,13 +100,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy, N, alpha, - x.data(), incx, y.data(), incy, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::axpy, N, alpha, + x.data(), incx, y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy, N, alpha, - x.data(), incx, y.data(), incy, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::axpy, N, alpha, + x.data(), incx, y.data(), incy, dependencies); break; default: break; } @@ -146,6 +146,8 @@ TEST_P(AxpyUsmTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2, alpha)); } TEST_P(AxpyUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, alpha)); @@ -164,6 +166,8 @@ TEST_P(AxpyUsmTests, ComplexSinglePrecision) { 1357, -3, -2, alpha)); } TEST_P(AxpyUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, alpha)); @@ -175,7 +179,7 @@ TEST_P(AxpyUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(AxpyUsmTestSuite, AxpyUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/copy.cpp b/tests/unit_tests/blas/level1/copy.cpp index c88aa0db3..87a1c2f1b 100644 --- a/tests/unit_tests/blas/level1/copy.cpp +++ b/tests/unit_tests/blas/level1/copy.cpp @@ -84,7 +84,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::copy(main_queue, N, x_buffer, incx, y_buffer, incy); break; @@ -95,13 +95,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy, N, x_buffer, - incx, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy, N, + x_buffer, incx, y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy, N, x_buffer, - incx, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy, N, x_buffer, + incx, y_buffer, incy); break; default: break; } @@ -122,7 +122,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, N, incy, N, std::cout); return (int)good; @@ -137,6 +137,8 @@ TEST_P(CopyTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(CopyTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); @@ -150,6 +152,8 @@ TEST_P(CopyTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(CopyTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -160,7 +164,7 @@ TEST_P(CopyTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(CopyTestSuite, CopyTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/copy_usm.cpp b/tests/unit_tests/blas/level1/copy_usm.cpp index ae9ffbd59..0f491015b 100644 --- a/tests/unit_tests/blas/level1/copy_usm.cpp +++ b/tests/unit_tests/blas/level1/copy_usm.cpp @@ -86,7 +86,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::copy(main_queue, N, x.data(), incx, y.data(), incy, dependencies); break; @@ -99,13 +99,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy, N, x.data(), - incx, y.data(), incy, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::copy, N, + x.data(), incx, y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy, N, x.data(), - incx, y.data(), incy, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::copy, N, x.data(), + incx, y.data(), incy, dependencies); break; default: break; } @@ -141,6 +141,8 @@ TEST_P(CopyUsmTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(CopyUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); @@ -154,6 +156,8 @@ TEST_P(CopyUsmTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(CopyUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -164,7 +168,7 @@ TEST_P(CopyUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(CopyUsmTestSuite, CopyUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/dot.cpp b/tests/unit_tests/blas/level1/dot.cpp index bc510d162..11cb09bcc 100644 --- a/tests/unit_tests/blas/level1/dot.cpp +++ b/tests/unit_tests/blas/level1/dot.cpp @@ -84,7 +84,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::dot(main_queue, N, x_buffer, incx, y_buffer, incy, result_buffer); break; @@ -96,13 +96,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dot, N, x_buffer, - incx, y_buffer, incy, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dot, N, + x_buffer, incx, y_buffer, incy, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dot, N, x_buffer, incx, - y_buffer, incy, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dot, N, x_buffer, + incx, y_buffer, incy, result_buffer); break; default: break; } @@ -123,7 +123,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_ref, N, std::cout); return (int)good; @@ -140,6 +140,8 @@ TEST_P(DotTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2))); } TEST_P(DotTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3))); EXPECT_TRUEORSKIP( @@ -148,6 +150,7 @@ TEST_P(DotTests, RealDoublePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2))); } TEST_P(DotTests, RealDoubleSinglePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3))); EXPECT_TRUEORSKIP( @@ -158,7 +161,7 @@ TEST_P(DotTests, RealDoubleSinglePrecision) { INSTANTIATE_TEST_SUITE_P(DotTestSuite, DotTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/dot_usm.cpp b/tests/unit_tests/blas/level1/dot_usm.cpp index 3b1a7f049..b8780c75d 100644 --- a/tests/unit_tests/blas/level1/dot_usm.cpp +++ b/tests/unit_tests/blas/level1/dot_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -81,12 +81,21 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Call DPC++ DOT. - auto result_p = (fp_res*)oneapi::mkl::malloc_shared(64, sizeof(fp_res), *dev, cxt); + fp_res* result_p; + if constexpr (alloc_type == usm::alloc::shared) { + result_p = (fp_res*)oneapi::mkl::malloc_shared(64, sizeof(fp_res), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::device) { + result_p = (fp_res*)oneapi::mkl::malloc_device(64, sizeof(fp_res), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::dot(main_queue, N, x.data(), incx, y.data(), incy, result_p, dependencies); break; @@ -99,13 +108,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dot, N, x.data(), - incx, y.data(), incy, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dot, N, + x.data(), incx, y.data(), incy, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dot, N, x.data(), incx, - y.data(), incy, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dot, N, x.data(), + incx, y.data(), incy, result_p, dependencies); break; default: break; } @@ -126,9 +135,9 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { } // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal(*result_p, result_ref, N, std::cout); + bool good = check_equal_ptr(main_queue, result_p, result_ref, N, std::cout); - oneapi::mkl::free_shared(result_p, cxt); + oneapi::mkl::free_usm(result_p, cxt); return (int)good; } @@ -141,29 +150,38 @@ TEST_P(DotUsmTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1))); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 101, 1, 1))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2))); } TEST_P(DotUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1))); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1, 1))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2))); } TEST_P(DotUsmTests, RealDoubleSinglePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1))); + EXPECT_TRUEORSKIP((test( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1, 1))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2))); } INSTANTIATE_TEST_SUITE_P(DotUsmTestSuite, DotUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/dotc.cpp b/tests/unit_tests/blas/level1/dotc.cpp index ed3adb2c9..cb8d0fc37 100644 --- a/tests/unit_tests/blas/level1/dotc.cpp +++ b/tests/unit_tests/blas/level1/dotc.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::dotc(main_queue, N, x_buffer, incx, y_buffer, incy, result_buffer); break; @@ -98,13 +98,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotc, N, x_buffer, - incx, y_buffer, incy, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotc, N, + x_buffer, incx, y_buffer, incy, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotc, N, x_buffer, - incx, y_buffer, incy, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotc, N, x_buffer, + incx, y_buffer, incy, result_buffer); break; default: break; } @@ -125,7 +125,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_reference, N, std::cout); return (int)good; @@ -143,6 +143,8 @@ TEST_P(DotcTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(DotcTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -153,7 +155,7 @@ TEST_P(DotcTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(DotcTestSuite, DotcTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/dotc_usm.cpp b/tests/unit_tests/blas/level1/dotc_usm.cpp index 376af0fb1..ad05c9d3b 100644 --- a/tests/unit_tests/blas/level1/dotc_usm.cpp +++ b/tests/unit_tests/blas/level1/dotc_usm.cpp @@ -88,7 +88,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::dotc( main_queue, N, x.data(), incx, y.data(), incy, result_p, dependencies); break; @@ -101,13 +101,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotc, N, x.data(), - incx, y.data(), incy, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotc, N, + x.data(), incx, y.data(), incy, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotc, N, x.data(), - incx, y.data(), incy, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotc, N, x.data(), + incx, y.data(), incy, result_p, dependencies); break; default: break; } @@ -148,6 +148,8 @@ TEST_P(DotcUsmTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(DotcUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -158,7 +160,7 @@ TEST_P(DotcUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(DotcUsmTestSuite, DotcUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/dotu.cpp b/tests/unit_tests/blas/level1/dotu.cpp index b71a4b7d6..bbef3ad8c 100644 --- a/tests/unit_tests/blas/level1/dotu.cpp +++ b/tests/unit_tests/blas/level1/dotu.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::dotu(main_queue, N, x_buffer, incx, y_buffer, incy, result_buffer); break; @@ -98,13 +98,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotu, N, x_buffer, - incx, y_buffer, incy, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotu, N, + x_buffer, incx, y_buffer, incy, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotu, N, x_buffer, - incx, y_buffer, incy, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotu, N, x_buffer, + incx, y_buffer, incy, result_buffer); break; default: break; } @@ -125,7 +125,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_reference, N, std::cout); return (int)good; @@ -143,6 +143,8 @@ TEST_P(DotuTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(DotuTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -153,7 +155,7 @@ TEST_P(DotuTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(DotuTestSuite, DotuTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/dotu_usm.cpp b/tests/unit_tests/blas/level1/dotu_usm.cpp index 9bb9996a7..3f30bf5ff 100644 --- a/tests/unit_tests/blas/level1/dotu_usm.cpp +++ b/tests/unit_tests/blas/level1/dotu_usm.cpp @@ -88,7 +88,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::dotu( main_queue, N, x.data(), incx, y.data(), incy, result_p, dependencies); break; @@ -101,13 +101,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotu, N, x.data(), - incx, y.data(), incy, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::dotu, N, + x.data(), incx, y.data(), incy, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotu, N, x.data(), - incx, y.data(), incy, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::dotu, N, x.data(), + incx, y.data(), incy, result_p, dependencies); break; default: break; } @@ -147,6 +147,8 @@ TEST_P(DotuUsmTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, -2)); } TEST_P(DotuUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -157,7 +159,7 @@ TEST_P(DotuUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(DotuUsmTestSuite, DotuUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/iamax.cpp b/tests/unit_tests/blas/level1/iamax.cpp index 3f77c57e7..977f12b5d 100644 --- a/tests/unit_tests/blas/level1/iamax.cpp +++ b/tests/unit_tests/blas/level1/iamax.cpp @@ -82,7 +82,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::iamax(main_queue, N, x_buffer, incx, result_buffer); break; @@ -93,13 +93,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamax, N, x_buffer, - incx, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamax, N, + x_buffer, incx, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamax, N, x_buffer, - incx, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamax, N, + x_buffer, incx, result_buffer); break; default: break; } @@ -120,7 +120,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_ref, 0, std::cout); return (int)good; @@ -135,6 +135,8 @@ TEST_P(IamaxTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IamaxTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); @@ -148,6 +150,8 @@ TEST_P(IamaxTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IamaxTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP( @@ -158,7 +162,7 @@ TEST_P(IamaxTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(IamaxTestSuite, IamaxTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/iamax_usm.cpp b/tests/unit_tests/blas/level1/iamax_usm.cpp index caba39edd..405a79532 100644 --- a/tests/unit_tests/blas/level1/iamax_usm.cpp +++ b/tests/unit_tests/blas/level1/iamax_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -80,12 +80,21 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Call DPC++ IAMAX. - auto result_p = (int64_t*)oneapi::mkl::malloc_shared(64, sizeof(int64_t), *dev, cxt); + int64_t* result_p; + if constexpr (alloc_type == usm::alloc::shared) { + result_p = (int64_t*)oneapi::mkl::malloc_shared(64, sizeof(int64_t), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::device) { + result_p = (int64_t*)oneapi::mkl::malloc_device(64, sizeof(int64_t), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::iamax(main_queue, N, x.data(), incx, result_p, dependencies); break; @@ -98,13 +107,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamax, N, x.data(), - incx, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamax, N, + x.data(), incx, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamax, N, x.data(), - incx, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamax, N, + x.data(), incx, result_p, dependencies); break; default: break; } @@ -126,8 +135,8 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal(*result_p, result_ref, 0, std::cout); - oneapi::mkl::free_shared(result_p, cxt); + bool good = check_equal_ptr(main_queue, result_p, result_ref, 0, std::cout); + oneapi::mkl::free_usm(result_p, cxt); return (int)good; } @@ -138,11 +147,17 @@ class IamaxUsmTests TEST_P(IamaxUsmTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IamaxUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IamaxUsmTests, ComplexSinglePrecision) { @@ -150,21 +165,27 @@ TEST_P(IamaxUsmTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP((test, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IamaxUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP((test, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } INSTANTIATE_TEST_SUITE_P(IamaxUsmTestSuite, IamaxUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/iamin.cpp b/tests/unit_tests/blas/level1/iamin.cpp index 08296f86e..a52862cb6 100644 --- a/tests/unit_tests/blas/level1/iamin.cpp +++ b/tests/unit_tests/blas/level1/iamin.cpp @@ -82,7 +82,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::iamin(main_queue, N, x_buffer, incx, result_buffer); break; @@ -93,13 +93,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamin, N, x_buffer, - incx, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamin, N, + x_buffer, incx, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamin, N, x_buffer, - incx, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamin, N, + x_buffer, incx, result_buffer); break; default: break; } @@ -120,7 +120,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_ref, 0, std::cout); return (int)good; @@ -135,6 +135,8 @@ TEST_P(IaminTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IaminTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); @@ -148,6 +150,8 @@ TEST_P(IaminTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IaminTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP( @@ -158,7 +162,7 @@ TEST_P(IaminTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(IaminTestSuite, IaminTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/iamin_usm.cpp b/tests/unit_tests/blas/level1/iamin_usm.cpp index f13cb5f28..a3523c8e7 100644 --- a/tests/unit_tests/blas/level1/iamin_usm.cpp +++ b/tests/unit_tests/blas/level1/iamin_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -80,12 +80,21 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Call DPC++ IAMIN. - auto result_p = (int64_t*)oneapi::mkl::malloc_shared(64, sizeof(int64_t), *dev, cxt); + int64_t* result_p; + if constexpr (alloc_type == usm::alloc::shared) { + result_p = (int64_t*)oneapi::mkl::malloc_shared(64, sizeof(int64_t), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::device) { + result_p = (int64_t*)oneapi::mkl::malloc_device(64, sizeof(int64_t), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::iamin(main_queue, N, x.data(), incx, result_p, dependencies); break; @@ -98,13 +107,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamin, N, x.data(), - incx, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::iamin, N, + x.data(), incx, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamin, N, x.data(), - incx, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::iamin, N, + x.data(), incx, result_p, dependencies); break; default: break; } @@ -126,8 +135,8 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal(*result_p, result_ref, 0, std::cout); - oneapi::mkl::free_shared(result_p, cxt); + bool good = check_equal_ptr(main_queue, result_p, result_ref, 0, std::cout); + oneapi::mkl::free_usm(result_p, cxt); return (int)good; } @@ -138,11 +147,17 @@ class IaminUsmTests TEST_P(IaminUsmTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP(( + test(std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IaminUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IaminUsmTests, ComplexSinglePrecision) { @@ -150,21 +165,27 @@ TEST_P(IaminUsmTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP((test, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } TEST_P(IaminUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2)); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1)); + EXPECT_TRUEORSKIP((test, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3)); } INSTANTIATE_TEST_SUITE_P(IaminUsmTestSuite, IaminUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/nrm2.cpp b/tests/unit_tests/blas/level1/nrm2.cpp index 6ac81d1af..423cecb59 100644 --- a/tests/unit_tests/blas/level1/nrm2.cpp +++ b/tests/unit_tests/blas/level1/nrm2.cpp @@ -83,7 +83,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::nrm2(main_queue, N, x_buffer, incx, result_buffer); break; case oneapi::mkl::layout::row_major: @@ -93,13 +93,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::nrm2, N, x_buffer, - incx, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::nrm2, N, + x_buffer, incx, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::nrm2, N, x_buffer, - incx, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::nrm2, N, x_buffer, + incx, result_buffer); break; default: break; } @@ -120,7 +120,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_ref, N, std::cout); return (int)good; @@ -138,6 +138,8 @@ TEST_P(Nrm2Tests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } TEST_P(Nrm2Tests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP( @@ -154,6 +156,8 @@ TEST_P(Nrm2Tests, ComplexSinglePrecision) { std::get<1>(GetParam()), 1357, -3))); } TEST_P(Nrm2Tests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), @@ -164,7 +168,7 @@ TEST_P(Nrm2Tests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Nrm2TestSuite, Nrm2Tests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/nrm2_usm.cpp b/tests/unit_tests/blas/level1/nrm2_usm.cpp index cdaf077ef..8628738f4 100644 --- a/tests/unit_tests/blas/level1/nrm2_usm.cpp +++ b/tests/unit_tests/blas/level1/nrm2_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -81,12 +81,21 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Call DPC++ NRM2. - auto result_p = (fp_res*)oneapi::mkl::malloc_shared(64, sizeof(fp_res), *dev, cxt); + fp_res* result_p; + if constexpr (alloc_type == usm::alloc::shared) { + result_p = (fp_res*)oneapi::mkl::malloc_shared(64, sizeof(fp_res), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::device) { + result_p = (fp_res*)oneapi::mkl::malloc_device(64, sizeof(fp_res), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::nrm2(main_queue, N, x.data(), incx, result_p, dependencies); break; @@ -99,13 +108,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::nrm2, N, x.data(), - incx, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::nrm2, N, + x.data(), incx, result_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::nrm2, N, x.data(), - incx, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::nrm2, N, x.data(), + incx, result_p, dependencies); break; default: break; } @@ -127,8 +136,8 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx) { // Compare the results of reference implementation and DPC++ implementation. - bool good = check_equal(*result_p, result_ref, N, std::cout); - oneapi::mkl::free_shared(result_p, cxt); + bool good = check_equal_ptr(main_queue, result_p, result_ref, N, std::cout); + oneapi::mkl::free_usm(result_p, cxt); return (int)good; } @@ -141,14 +150,20 @@ TEST_P(Nrm2UsmTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } TEST_P(Nrm2UsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } @@ -157,21 +172,27 @@ TEST_P(Nrm2UsmTests, ComplexSinglePrecision) { std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP((test, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((test, float, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP((test, float>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } TEST_P(Nrm2UsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2))); EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1))); + EXPECT_TRUEORSKIP((test, double, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam()), 101, 1))); EXPECT_TRUEORSKIP((test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3))); } INSTANTIATE_TEST_SUITE_P(Nrm2UsmTestSuite, Nrm2UsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rot.cpp b/tests/unit_tests/blas/level1/rot.cpp index 4bf810049..f65540182 100644 --- a/tests/unit_tests/blas/level1/rot.cpp +++ b/tests/unit_tests/blas/level1/rot.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp_ try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::rot(main_queue, N, x_buffer, incx, y_buffer, incy, c, s); break; @@ -98,13 +98,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp_ } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rot, N, x_buffer, - incx, y_buffer, incy, c, s); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rot, N, + x_buffer, incx, y_buffer, incy, c, s); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rot, N, x_buffer, incx, - y_buffer, incy, c, s); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rot, N, x_buffer, + incx, y_buffer, incy, c, s); break; default: break; } @@ -125,9 +125,9 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp_ // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good_x = check_equal_vector(x_accessor, x_ref, N, incx, N, std::cout); - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good_y = check_equal_vector(y_accessor, y_ref, N, incy, N, std::cout); bool good = good_x && good_y; @@ -149,6 +149,8 @@ TEST_P(RotTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3, c, s))); } TEST_P(RotTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double c(2.0); double s(-0.5); EXPECT_TRUEORSKIP( @@ -169,6 +171,8 @@ TEST_P(RotTests, ComplexSinglePrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3, c, s))); } TEST_P(RotTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double c = 2.0; double s = -0.5; EXPECT_TRUEORSKIP((test, double>( @@ -181,7 +185,7 @@ TEST_P(RotTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(RotTestSuite, RotTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rot_usm.cpp b/tests/unit_tests/blas/level1/rot_usm.cpp index 457f61d67..287ac285b 100644 --- a/tests/unit_tests/blas/level1/rot_usm.cpp +++ b/tests/unit_tests/blas/level1/rot_usm.cpp @@ -88,7 +88,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp_ try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::rot(main_queue, N, x.data(), incx, y.data(), incy, c, s, dependencies); break; @@ -101,13 +101,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp_ done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rot, N, x.data(), - incx, y.data(), incy, c, s, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rot, N, + x.data(), incx, y.data(), incy, c, s, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rot, N, x.data(), incx, - y.data(), incy, c, s, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rot, N, x.data(), + incx, y.data(), incy, c, s, dependencies); break; default: break; } @@ -150,6 +150,8 @@ TEST_P(RotUsmTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3, c, s))); } TEST_P(RotUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double c(2.0); double s(-0.5); EXPECT_TRUEORSKIP( @@ -170,6 +172,8 @@ TEST_P(RotUsmTests, ComplexSinglePrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3, c, s))); } TEST_P(RotUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double c = 2.0; double s = -0.5; EXPECT_TRUEORSKIP((test, double>( @@ -182,7 +186,7 @@ TEST_P(RotUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(RotUsmTestSuite, RotUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rotg.cpp b/tests/unit_tests/blas/level1/rotg.cpp index 50f8f4720..1a0d569d8 100644 --- a/tests/unit_tests/blas/level1/rotg.cpp +++ b/tests/unit_tests/blas/level1/rotg.cpp @@ -92,7 +92,7 @@ int test(device *dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::rotg(main_queue, a_buffer, b_buffer, c_buffer, s_buffer); break; @@ -104,13 +104,13 @@ int test(device *dev, oneapi::mkl::layout layout) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotg, a_buffer, - b_buffer, c_buffer, s_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotg, a_buffer, + b_buffer, c_buffer, s_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotg, a_buffer, - b_buffer, c_buffer, s_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotg, a_buffer, + b_buffer, c_buffer, s_buffer); break; default: break; } @@ -130,13 +130,13 @@ int test(device *dev, oneapi::mkl::layout layout) { } // Compare the results of reference implementation and DPC++ implementation. - auto a_accessor = a_buffer.template get_host_access(read_only); + auto a_accessor = a_buffer.get_host_access(read_only); bool good_a = check_equal(a_accessor[0], a_ref, 4, std::cout); - auto b_accessor = b_buffer.template get_host_access(read_only); + auto b_accessor = b_buffer.get_host_access(read_only); bool good_b = check_equal(b_accessor[0], b_ref, 4, std::cout); - auto s_accessor = s_buffer.template get_host_access(read_only); + auto s_accessor = s_buffer.get_host_access(read_only); bool good_s = check_equal(s_accessor[0], s_ref, 4, std::cout); - auto c_accessor = c_buffer.template get_host_access(read_only); + auto c_accessor = c_buffer.get_host_access(read_only); bool good_c = check_equal(c_accessor[0], c_ref, 4, std::cout); bool good = good_a && good_b && good_c && good_s; @@ -153,6 +153,8 @@ TEST_P(RotgTests, RealSinglePrecision) { EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam())))); } TEST_P(RotgTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam())))); EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam())))); EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam())))); @@ -166,6 +168,8 @@ TEST_P(RotgTests, ComplexSinglePrecision) { (test, float>(std::get<0>(GetParam()), std::get<1>(GetParam())))); } TEST_P(RotgTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (test, double>(std::get<0>(GetParam()), std::get<1>(GetParam())))); EXPECT_TRUEORSKIP( @@ -176,7 +180,7 @@ TEST_P(RotgTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(RotgTestSuite, RotgTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rotg_usm.cpp b/tests/unit_tests/blas/level1/rotg_usm.cpp index 72c096a94..de71a793d 100644 --- a/tests/unit_tests/blas/level1/rotg_usm.cpp +++ b/tests/unit_tests/blas/level1/rotg_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -86,20 +86,34 @@ int test(device *dev, oneapi::mkl::layout layout) { ::rotg((fp_ref *)&a_ref, (fp_ref *)&b_ref, (fp_scalar *)&c_ref, (fp_ref *)&s_ref); // Call DPC++ ROTG. - fp *a_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); - fp *b_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); - fp *s_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); - fp_scalar *c_p = (fp_scalar *)oneapi::mkl::malloc_shared(64, sizeof(fp_scalar), *dev, cxt); + fp *a_p, *b_p, *s_p; + fp_scalar *c_p; + if constexpr (alloc_type == usm::alloc::shared) { + a_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); + b_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); + s_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); + c_p = (fp_scalar *)oneapi::mkl::malloc_shared(64, sizeof(fp_scalar), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::device) { + a_p = (fp *)oneapi::mkl::malloc_device(64, sizeof(fp), *dev, cxt); + b_p = (fp *)oneapi::mkl::malloc_device(64, sizeof(fp), *dev, cxt); + s_p = (fp *)oneapi::mkl::malloc_device(64, sizeof(fp), *dev, cxt); + c_p = (fp_scalar *)oneapi::mkl::malloc_device(64, sizeof(fp_scalar), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } - a_p[0] = a; - b_p[0] = b; - s_p[0] = s; - c_p[0] = c; + main_queue.memcpy(a_p, &a, sizeof(fp)); + main_queue.memcpy(b_p, &b, sizeof(fp)); + main_queue.memcpy(s_p, &s, sizeof(fp)); + main_queue.memcpy(c_p, &c, sizeof(fp_scalar)); + main_queue.wait(); try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::rotg(main_queue, a_p, b_p, c_p, s_p, dependencies); break; @@ -112,13 +126,13 @@ int test(device *dev, oneapi::mkl::layout layout) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotg, a_p, b_p, c_p, - s_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotg, a_p, b_p, + c_p, s_p, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotg, a_p, b_p, c_p, - s_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotg, a_p, b_p, + c_p, s_p, dependencies); break; default: break; } @@ -140,17 +154,17 @@ int test(device *dev, oneapi::mkl::layout layout) { // Compare the results of reference implementation and DPC++ implementation. - bool good_a = check_equal(a_p[0], a_ref, 4, std::cout); - bool good_b = check_equal(b_p[0], b_ref, 4, std::cout); - bool good_s = check_equal(s_p[0], s_ref, 4, std::cout); - bool good_c = check_equal(c_p[0], c_ref, 4, std::cout); + bool good_a = check_equal_ptr(main_queue, a_p, a_ref, 4, std::cout); + bool good_b = check_equal_ptr(main_queue, b_p, b_ref, 4, std::cout); + bool good_s = check_equal_ptr(main_queue, s_p, s_ref, 4, std::cout); + bool good_c = check_equal_ptr(main_queue, c_p, c_ref, 4, std::cout); bool good = good_a && good_b && good_c && good_s; - oneapi::mkl::free_shared(a_p, cxt); - oneapi::mkl::free_shared(b_p, cxt); - oneapi::mkl::free_shared(s_p, cxt); - oneapi::mkl::free_shared(c_p, cxt); + oneapi::mkl::free_usm(a_p, cxt); + oneapi::mkl::free_usm(b_p, cxt); + oneapi::mkl::free_usm(s_p, cxt); + oneapi::mkl::free_usm(c_p, cxt); return (int)good; } @@ -160,22 +174,34 @@ class RotgUsmTests TEST_P(RotgUsmTests, RealSinglePrecision) { EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam())))); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam())))); } TEST_P(RotgUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), std::get<1>(GetParam())))); + EXPECT_TRUEORSKIP((test(std::get<0>(GetParam()), + std::get<1>(GetParam())))); } TEST_P(RotgUsmTests, ComplexSinglePrecision) { EXPECT_TRUEORSKIP( (test, float>(std::get<0>(GetParam()), std::get<1>(GetParam())))); + EXPECT_TRUEORSKIP((test, float, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam())))); } TEST_P(RotgUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( (test, double>(std::get<0>(GetParam()), std::get<1>(GetParam())))); + EXPECT_TRUEORSKIP((test, double, usm::alloc::device>( + std::get<0>(GetParam()), std::get<1>(GetParam())))); } INSTANTIATE_TEST_SUITE_P(RotgUsmTestSuite, RotgUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rotm.cpp b/tests/unit_tests/blas/level1/rotm.cpp index 1364c7ae5..ab2c599bf 100644 --- a/tests/unit_tests/blas/level1/rotm.cpp +++ b/tests/unit_tests/blas/level1/rotm.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::rotm(main_queue, N, x_buffer, incx, y_buffer, incy, param_buffer); break; @@ -101,13 +101,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotm, N, x_buffer, - incx, y_buffer, incy, param_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotm, N, + x_buffer, incx, y_buffer, incy, param_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotm, N, x_buffer, - incx, y_buffer, incy, param_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotm, N, x_buffer, + incx, y_buffer, incy, param_buffer); break; default: break; } @@ -127,9 +127,9 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good_x = check_equal_vector(x_accessor, x_ref, N, incx, N, std::cout); - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good_y = check_equal_vector(y_accessor, y_ref, N, incy, N, std::cout); bool good = good_x && good_y; @@ -170,6 +170,8 @@ TEST_P(RotmTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1, flag)); } TEST_P(RotmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double flag(-1.0); EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, flag)); @@ -202,7 +204,7 @@ TEST_P(RotmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(RotmTestSuite, RotmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rotm_usm.cpp b/tests/unit_tests/blas/level1/rotm_usm.cpp index b749cc087..7723e096c 100644 --- a/tests/unit_tests/blas/level1/rotm_usm.cpp +++ b/tests/unit_tests/blas/level1/rotm_usm.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::rotm( main_queue, N, x.data(), incx, y.data(), incy, param.data(), dependencies); break; @@ -102,13 +102,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, fp done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotm, N, x.data(), - incx, y.data(), incy, param.data(), dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotm, N, + x.data(), incx, y.data(), incy, param.data(), dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotm, N, x.data(), - incx, y.data(), incy, param.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotm, N, x.data(), + incx, y.data(), incy, param.data(), dependencies); break; default: break; } @@ -171,6 +171,8 @@ TEST_P(RotmUsmTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1, flag)); } TEST_P(RotmUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double flag(-1.0); EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, flag)); @@ -203,7 +205,7 @@ TEST_P(RotmUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(RotmUsmTestSuite, RotmUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rotmg.cpp b/tests/unit_tests/blas/level1/rotmg.cpp index 2f713b2e0..f62bd1cf9 100644 --- a/tests/unit_tests/blas/level1/rotmg.cpp +++ b/tests/unit_tests/blas/level1/rotmg.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::rotmg(main_queue, d1_buffer, d2_buffer, x1_buffer, y1, param_buffer); break; @@ -101,13 +101,13 @@ int test(device* dev, oneapi::mkl::layout layout) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotmg, d1_buffer, - d2_buffer, x1_buffer, y1, param_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotmg, + d1_buffer, d2_buffer, x1_buffer, y1, param_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotmg, d1_buffer, - d2_buffer, x1_buffer, y1, param_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotmg, d1_buffer, + d2_buffer, x1_buffer, y1, param_buffer); break; default: break; } @@ -130,15 +130,59 @@ int test(device* dev, oneapi::mkl::layout layout) { int error_mag = 50; - auto d1_accessor = d1_buffer.template get_host_access(read_only); + auto d1_accessor = d1_buffer.get_host_access(read_only); bool good_d1 = check_equal(d1_accessor[0], d1_ref, error_mag, std::cout); - auto d2_accessor = d2_buffer.template get_host_access(read_only); + auto d2_accessor = d2_buffer.get_host_access(read_only); bool good_d2 = check_equal(d2_accessor[0], d2_ref, error_mag, std::cout); - auto x1_accessor = x1_buffer.template get_host_access(read_only); + auto x1_accessor = x1_buffer.get_host_access(read_only); bool good_x1 = check_equal(x1_accessor[0], x1_ref, error_mag, std::cout); - auto param_accessor = param_buffer.template get_host_access(read_only); - bool good_param = check_equal_vector(param_accessor, param_ref, 5, 1, error_mag, std::cout); - bool good = good_d1 && good_d2 && good_x1 && good_param; + auto param_accessor = param_buffer.get_host_access(read_only); + + constexpr fp unit_matrix = -2; + constexpr fp rescaled_matrix = -1; + constexpr fp sltc_matrix = 0; + constexpr fp clts_matrix = 1; + + fp flag = param_accessor[0]; + fp h11 = param_accessor[1]; + fp h12 = param_accessor[3]; + fp h21 = param_accessor[2]; + fp h22 = param_accessor[4]; + + fp flag_ref = param_ref[0]; + fp h11_ref = param_ref[1]; + fp h12_ref = param_ref[3]; + fp h21_ref = param_ref[2]; + fp h22_ref = param_ref[4]; + + bool flag_good = (flag_ref == flag); + bool h11_good = true; + bool h12_good = true; + bool h21_good = true; + bool h22_good = true; + + /* Some values of param have to be ignored depending on the flag value since they are + * implementation defined */ + if (flag_ref != unit_matrix) { + if (flag_ref == sltc_matrix) { + h12_good = check_equal(h12, h12_ref, error_mag, std::cout); + h21_good = check_equal(h21, h21_ref, error_mag, std::cout); + } + else if (flag_ref == clts_matrix) { + h11_good = check_equal(h11, h11_ref, error_mag, std::cout); + h22_good = check_equal(h22, h22_ref, error_mag, std::cout); + } + else { + flag_good = flag_good && (flag == rescaled_matrix); + h11_good = check_equal(h11, h11_ref, error_mag, std::cout); + h12_good = check_equal(h12, h12_ref, error_mag, std::cout); + h21_good = check_equal(h21, h21_ref, error_mag, std::cout); + h22_good = check_equal(h22, h22_ref, error_mag, std::cout); + } + } + + bool good = + good_d1 && good_d2 && good_x1 && flag_good && h11_good && h12_good && h21_good && h22_good; return (int)good; } @@ -150,12 +194,14 @@ TEST_P(RotmgTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } TEST_P(RotmgTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); } INSTANTIATE_TEST_SUITE_P(RotmgTestSuite, RotmgTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/rotmg_usm.cpp b/tests/unit_tests/blas/level1/rotmg_usm.cpp index 8d0dadfbf..92eeee491 100644 --- a/tests/unit_tests/blas/level1/rotmg_usm.cpp +++ b/tests/unit_tests/blas/level1/rotmg_usm.cpp @@ -45,7 +45,7 @@ extern std::vector devices; namespace { -template +template int test(device *dev, oneapi::mkl::layout layout) { // Catch asynchronous exceptions. auto exception_handler = [](exception_list exceptions) { @@ -80,22 +80,35 @@ int test(device *dev, oneapi::mkl::layout layout) { d2_ref = d2; x1_ref = x1; + fp *d1_p, *d2_p, *x1_p; + if constexpr (alloc_type == usm::alloc::device) { + d1_p = (fp *)oneapi::mkl::malloc_device(64, sizeof(fp), *dev, cxt); + d2_p = (fp *)oneapi::mkl::malloc_device(64, sizeof(fp), *dev, cxt); + x1_p = (fp *)oneapi::mkl::malloc_device(64, sizeof(fp), *dev, cxt); + } + else if constexpr (alloc_type == usm::alloc::shared) { + d1_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); + d2_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); + x1_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); + } + else { + throw std::runtime_error("Bad alloc_type"); + } + main_queue.memcpy(d1_p, &d1, sizeof(fp)); + main_queue.memcpy(d2_p, &d2, sizeof(fp)); + main_queue.memcpy(x1_p, &x1, sizeof(fp)); + main_queue.wait(); + // Call Reference ROTMG. ::rotmg(&d1_ref, &d2_ref, &x1_ref, &y1, (fp *)param_ref.data()); // Call DPC++ ROTMG. - fp *d1_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); - fp *d2_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); - fp *x1_p = (fp *)oneapi::mkl::malloc_shared(64, sizeof(fp), *dev, cxt); - d1_p[0] = d1; - d2_p[0] = d2; - x1_p[0] = x1; try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::rotmg(main_queue, d1_p, d2_p, x1_p, y1, param.data(), dependencies); break; @@ -108,13 +121,13 @@ int test(device *dev, oneapi::mkl::layout layout) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotmg, d1_p, d2_p, - x1_p, y1, param.data(), dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::rotmg, d1_p, + d2_p, x1_p, y1, param.data(), dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotmg, d1_p, d2_p, - x1_p, y1, param.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::rotmg, d1_p, d2_p, + x1_p, y1, param.data(), dependencies); break; default: break; } @@ -134,17 +147,67 @@ int test(device *dev, oneapi::mkl::layout layout) { std::cout << "Error raised during execution of ROTMG:\n" << error.what() << std::endl; } + int error_mag = 50; + // Compare the results of reference implementation and DPC++ implementation. - bool good_d1 = check_equal(d1_p[0], d1_ref, 1, std::cout); - bool good_d2 = check_equal(d2_p[0], d2_ref, 1, std::cout); - bool good_x1 = check_equal(x1_p[0], x1_ref, 1, std::cout); - bool good_param = check_equal_vector(param, param_ref, 5, 1, 1, std::cout); - bool good = good_d1 && good_d2 && good_x1 && good_param; + bool good_d1 = check_equal_ptr(main_queue, d1_p, d1_ref, error_mag, std::cout); + bool good_d2 = check_equal_ptr(main_queue, d2_p, d2_ref, error_mag, std::cout); + bool good_x1 = check_equal_ptr(main_queue, x1_p, x1_ref, error_mag, std::cout); + + constexpr fp unit_matrix = -2; + constexpr fp rescaled_matrix = -1; + constexpr fp sltc_matrix = 0; + constexpr fp clts_matrix = 1; + + fp param_host[5]; + main_queue.memcpy(¶m_host, param.data(), sizeof(fp) * 5); + main_queue.wait(); + + fp flag = param_host[0]; + fp h11 = param_host[1]; + fp h12 = param_host[3]; + fp h21 = param_host[2]; + fp h22 = param_host[4]; + + fp flag_ref = param_ref[0]; + fp h11_ref = param_ref[1]; + fp h12_ref = param_ref[3]; + fp h21_ref = param_ref[2]; + fp h22_ref = param_ref[4]; + + bool flag_good = (flag_ref == flag); + bool h11_good = true; + bool h12_good = true; + bool h21_good = true; + bool h22_good = true; + + /* Some values of param have to be ignored depending on the flag value since they are + * implementation defined */ + if (flag_ref != unit_matrix) { + if (flag_ref == sltc_matrix) { + h12_good = check_equal(h12, h12_ref, error_mag, std::cout); + h21_good = check_equal(h21, h21_ref, error_mag, std::cout); + } + else if (flag_ref == clts_matrix) { + h11_good = check_equal(h11, h11_ref, error_mag, std::cout); + h22_good = check_equal(h22, h22_ref, error_mag, std::cout); + } + else { + flag_good = flag_good && (flag == rescaled_matrix); + h11_good = check_equal(h11, h11_ref, error_mag, std::cout); + h12_good = check_equal(h12, h12_ref, error_mag, std::cout); + h21_good = check_equal(h21, h21_ref, error_mag, std::cout); + h22_good = check_equal(h22, h22_ref, error_mag, std::cout); + } + } - oneapi::mkl::free_shared(d1_p, cxt); - oneapi::mkl::free_shared(d2_p, cxt); - oneapi::mkl::free_shared(x1_p, cxt); + bool good = + good_d1 && good_d2 && good_x1 && flag_good && h11_good && h12_good && h21_good && h22_good; + + oneapi::mkl::free_usm(d1_p, cxt); + oneapi::mkl::free_usm(d2_p, cxt); + oneapi::mkl::free_usm(x1_p, cxt); return (int)good; } @@ -154,14 +217,20 @@ class RotmgUsmTests TEST_P(RotmgUsmTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam())))); } TEST_P(RotmgUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()))); + EXPECT_TRUEORSKIP( + (test(std::get<0>(GetParam()), std::get<1>(GetParam())))); } INSTANTIATE_TEST_SUITE_P(RotmgUsmTestSuite, RotmgUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/scal.cpp b/tests/unit_tests/blas/level1/scal.cpp index 4265c890e..8901bb424 100644 --- a/tests/unit_tests/blas/level1/scal.cpp +++ b/tests/unit_tests/blas/level1/scal.cpp @@ -84,7 +84,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, fp_scalar alp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::scal(main_queue, N, alpha, x_buffer, incx); break; case oneapi::mkl::layout::row_major: @@ -94,13 +94,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, fp_scalar alp } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::scal, N, alpha, - x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::scal, N, alpha, + x_buffer, incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::scal, N, alpha, - x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::scal, N, alpha, + x_buffer, incx); break; default: break; } @@ -120,7 +120,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, fp_scalar alp } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_vector(x_accessor, x_ref, N, incx, N, std::cout); return (int)good; @@ -137,6 +137,8 @@ TEST_P(ScalTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, alpha))); } TEST_P(ScalTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, alpha))); @@ -151,6 +153,8 @@ TEST_P(ScalTests, ComplexSinglePrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, alpha))); } TEST_P(ScalTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP((test, std::complex>( std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, alpha))); @@ -165,6 +169,8 @@ TEST_P(ScalTests, ComplexRealSinglePrecision) { std::get<1>(GetParam()), 1357, -3, alpha))); } TEST_P(ScalTests, ComplexRealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP((test, double>( std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, alpha))); @@ -174,7 +180,7 @@ TEST_P(ScalTests, ComplexRealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(ScalTestSuite, ScalTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/scal_usm.cpp b/tests/unit_tests/blas/level1/scal_usm.cpp index ad74d5e3f..e669deb2d 100644 --- a/tests/unit_tests/blas/level1/scal_usm.cpp +++ b/tests/unit_tests/blas/level1/scal_usm.cpp @@ -87,7 +87,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, fp_scalar alp try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::scal(main_queue, N, alpha, x.data(), incx, dependencies); break; @@ -100,13 +100,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, fp_scalar alp done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::scal, N, alpha, - x.data(), incx, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::scal, N, alpha, + x.data(), incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::scal, N, alpha, - x.data(), incx, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::scal, N, alpha, + x.data(), incx, dependencies); break; default: break; } @@ -144,6 +144,8 @@ TEST_P(ScalUsmTests, RealSinglePrecision) { (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, alpha))); } TEST_P(ScalUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( (test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, alpha))); @@ -158,6 +160,8 @@ TEST_P(ScalUsmTests, ComplexSinglePrecision) { std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -3, alpha))); } TEST_P(ScalUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP((test, std::complex>( std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, alpha))); @@ -172,6 +176,8 @@ TEST_P(ScalUsmTests, ComplexRealSinglePrecision) { std::get<1>(GetParam()), 1357, -3, alpha))); } TEST_P(ScalUsmTests, ComplexRealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP((test, double>( std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, alpha))); @@ -181,7 +187,7 @@ TEST_P(ScalUsmTests, ComplexRealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(ScalUsmTestSuite, ScalUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/sdsdot.cpp b/tests/unit_tests/blas/level1/sdsdot.cpp index 5cb8835dd..7293a3699 100644 --- a/tests/unit_tests/blas/level1/sdsdot.cpp +++ b/tests/unit_tests/blas/level1/sdsdot.cpp @@ -84,7 +84,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, flo try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::sdsdot(main_queue, N, alpha, x_buffer, incx, y_buffer, incy, result_buffer); break; @@ -96,13 +96,13 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, flo } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sdsdot, N, alpha, - x_buffer, incx, y_buffer, incy, result_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sdsdot, N, + alpha, x_buffer, incx, y_buffer, incy, result_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sdsdot, N, alpha, - x_buffer, incx, y_buffer, incy, result_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sdsdot, N, alpha, + x_buffer, incx, y_buffer, incy, result_buffer); break; default: break; } @@ -123,7 +123,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, flo // Compare the results of reference implementation and DPC++ implementation. - auto result_accessor = result_buffer.template get_host_access(read_only); + auto result_accessor = result_buffer.get_host_access(read_only); bool good = check_equal(result_accessor[0], result_ref, N, std::cout); return (int)good; @@ -133,6 +133,7 @@ class SdsdotTests : public ::testing::TestWithParam> {}; TEST_P(SdsdotTests, RealSinglePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, 2.0)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3, 2.0)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1, 2.0)); @@ -140,7 +141,7 @@ TEST_P(SdsdotTests, RealSinglePrecision) { INSTANTIATE_TEST_SUITE_P(SdsdotTestSuite, SdsdotTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/sdsdot_usm.cpp b/tests/unit_tests/blas/level1/sdsdot_usm.cpp index 6ca37b97d..a5740516c 100644 --- a/tests/unit_tests/blas/level1/sdsdot_usm.cpp +++ b/tests/unit_tests/blas/level1/sdsdot_usm.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, flo try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::sdsdot( main_queue, N, alpha, x.data(), incx, y.data(), incy, result_p, dependencies); break; @@ -99,13 +99,14 @@ int test(device *dev, oneapi::mkl::layout layout, int N, int incx, int incy, flo done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sdsdot, N, alpha, - x.data(), incx, y.data(), incy, result_p, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sdsdot, N, + alpha, x.data(), incx, y.data(), incy, result_p, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sdsdot, N, alpha, - x.data(), incx, y.data(), incy, result_p, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sdsdot, N, alpha, + x.data(), incx, y.data(), incy, result_p, dependencies); break; default: break; } @@ -137,6 +138,7 @@ class SdsdotUsmTests : public ::testing::TestWithParam> {}; TEST_P(SdsdotUsmTests, RealSinglePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3, 2.0)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3, 2.0)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1, 2.0)); @@ -144,7 +146,7 @@ TEST_P(SdsdotUsmTests, RealSinglePrecision) { INSTANTIATE_TEST_SUITE_P(SdsdotUsmTestSuite, SdsdotUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/swap.cpp b/tests/unit_tests/blas/level1/swap.cpp index 5f59672d5..6c6721537 100644 --- a/tests/unit_tests/blas/level1/swap.cpp +++ b/tests/unit_tests/blas/level1/swap.cpp @@ -84,7 +84,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::swap(main_queue, N, x_buffer, incx, y_buffer, incy); break; @@ -95,13 +95,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::swap, N, x_buffer, - incx, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::swap, N, + x_buffer, incx, y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::swap, N, x_buffer, - incx, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::swap, N, x_buffer, + incx, y_buffer, incy); break; default: break; } @@ -122,8 +122,8 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); - auto x_accessor = x_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good_y = check_equal_vector(y_accessor, y_ref, N, incy, N, std::cout); bool good_x = check_equal_vector(x_accessor, x_ref, N, incx, N, std::cout); bool good = good_x && good_y; @@ -140,6 +140,8 @@ TEST_P(SwapTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); } TEST_P(SwapTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); @@ -153,6 +155,8 @@ TEST_P(SwapTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); } TEST_P(SwapTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -163,7 +167,7 @@ TEST_P(SwapTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SwapTestSuite, SwapTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level1/swap_usm.cpp b/tests/unit_tests/blas/level1/swap_usm.cpp index c92d26ae8..de20f3eb7 100644 --- a/tests/unit_tests/blas/level1/swap_usm.cpp +++ b/tests/unit_tests/blas/level1/swap_usm.cpp @@ -86,7 +86,7 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::swap(main_queue, N, x.data(), incx, y.data(), incy, dependencies); break; @@ -99,13 +99,13 @@ int test(device* dev, oneapi::mkl::layout layout, int N, int incx, int incy) { done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::swap, N, x.data(), - incx, y.data(), incy, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::swap, N, + x.data(), incx, y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::swap, N, x.data(), - incx, y.data(), incy, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::swap, N, x.data(), + incx, y.data(), incy, dependencies); break; default: break; } @@ -143,6 +143,8 @@ TEST_P(SwapUsmTests, RealSinglePrecision) { EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); } TEST_P(SwapUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, -2, -3)); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); @@ -156,6 +158,8 @@ TEST_P(SwapUsmTests, ComplexSinglePrecision) { test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 1, 1)); } TEST_P(SwapUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP( test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 1357, 2, 3)); EXPECT_TRUEORSKIP( @@ -166,7 +170,7 @@ TEST_P(SwapUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SwapUsmTestSuite, SwapUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/gbmv.cpp b/tests/unit_tests/blas/level2/gbmv.cpp index 9b339daa7..94fcbc906 100644 --- a/tests/unit_tests/blas/level2/gbmv.cpp +++ b/tests/unit_tests/blas/level2/gbmv.cpp @@ -94,7 +94,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gbmv(main_queue, transa, m, n, kl, ku, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); @@ -108,14 +108,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gbmv, transa, m, n, - kl, ku, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, - incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gbmv, transa, + m, n, kl, ku, alpha, A_buffer, lda, x_buffer, incx, beta, + y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gbmv, transa, m, n, kl, - ku, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gbmv, transa, m, + n, kl, ku, alpha, A_buffer, lda, x_buffer, incx, beta, + y_buffer, incy); break; default: break; } @@ -135,7 +136,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, y_len, incy, std::max(m, n), std::cout); return (int)good; @@ -167,6 +168,8 @@ TEST_P(GbmvTests, RealSinglePrecision) { 42)); } TEST_P(GbmvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -220,6 +223,8 @@ TEST_P(GbmvTests, ComplexSinglePrecision) { alpha, beta, 1, 1, 42)); } TEST_P(GbmvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -253,7 +258,7 @@ TEST_P(GbmvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GbmvTestSuite, GbmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/gbmv_usm.cpp b/tests/unit_tests/blas/level2/gbmv_usm.cpp index f67ae34c2..9d92fcf7e 100644 --- a/tests/unit_tests/blas/level2/gbmv_usm.cpp +++ b/tests/unit_tests/blas/level2/gbmv_usm.cpp @@ -94,7 +94,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gbmv(main_queue, transa, m, n, kl, ku, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, dependencies); @@ -109,15 +109,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gbmv, transa, m, n, - kl, ku, alpha, A.data(), lda, x.data(), incx, beta, y.data(), - incy, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gbmv, transa, + m, n, kl, ku, alpha, A.data(), lda, x.data(), incx, beta, + y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gbmv, transa, m, n, kl, - ku, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gbmv, transa, m, + n, kl, ku, alpha, A.data(), lda, x.data(), incx, beta, + y.data(), incy, dependencies); break; default: break; } @@ -170,6 +170,8 @@ TEST_P(GbmvUsmTests, RealSinglePrecision) { 42)); } TEST_P(GbmvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -223,6 +225,8 @@ TEST_P(GbmvUsmTests, ComplexSinglePrecision) { alpha, beta, 1, 1, 42)); } TEST_P(GbmvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -256,7 +260,7 @@ TEST_P(GbmvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GbmvUsmTestSuite, GbmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/gemv.cpp b/tests/unit_tests/blas/level2/gemv.cpp index b60f8d24c..3bfff4324 100644 --- a/tests/unit_tests/blas/level2/gemv.cpp +++ b/tests/unit_tests/blas/level2/gemv.cpp @@ -93,7 +93,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gemv(main_queue, transa, m, n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); break; @@ -105,13 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv, transa, m, n, - alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv, transa, + m, n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, + incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv, transa, m, n, - alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv, transa, m, + n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, + incy); break; default: break; } @@ -131,7 +133,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, y_len, incy, std::max(m, n), std::cout); return (int)good; @@ -158,6 +160,8 @@ TEST_P(GemvTests, RealSinglePrecision) { oneapi::mkl::transpose::trans, 25, 30, alpha, beta, 1, 1, 42)); } TEST_P(GemvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -210,6 +214,8 @@ TEST_P(GemvTests, ComplexSinglePrecision) { } TEST_P(GemvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -243,7 +249,7 @@ TEST_P(GemvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemvTestSuite, GemvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/gemv_usm.cpp b/tests/unit_tests/blas/level2/gemv_usm.cpp index 5dca89b22..d1e726e38 100644 --- a/tests/unit_tests/blas/level2/gemv_usm.cpp +++ b/tests/unit_tests/blas/level2/gemv_usm.cpp @@ -93,7 +93,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemv(main_queue, transa, m, n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, dependencies); @@ -108,15 +108,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv, transa, m, n, - alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemv, transa, + m, n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), + incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv, transa, m, n, - alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemv, transa, m, + n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), + incy, dependencies); break; default: break; } @@ -164,6 +164,8 @@ TEST_P(GemvUsmTests, RealSinglePrecision) { oneapi::mkl::transpose::trans, 25, 30, alpha, beta, 1, 1, 42)); } TEST_P(GemvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -214,6 +216,8 @@ TEST_P(GemvUsmTests, ComplexSinglePrecision) { beta, 1, 1, 42)); } TEST_P(GemvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -247,7 +251,7 @@ TEST_P(GemvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemvUsmTestSuite, GemvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/ger.cpp b/tests/unit_tests/blas/level2/ger.cpp index 2d4160aac..3b32d2827 100644 --- a/tests/unit_tests/blas/level2/ger.cpp +++ b/tests/unit_tests/blas/level2/ger.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::ger(main_queue, m, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; @@ -102,13 +102,13 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::ger, m, n, alpha, - x_buffer, incx, y_buffer, incy, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::ger, m, n, + alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::ger, m, n, alpha, - x_buffer, incx, y_buffer, incy, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::ger, m, n, alpha, + x_buffer, incx, y_buffer, incy, A_buffer, lda); break; default: break; } @@ -128,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, m, n, lda, std::max(m, n), std::cout); @@ -148,6 +148,8 @@ TEST_P(GerTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 1, 1, 42)); } TEST_P(GerTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 2, 3, 42)); @@ -159,7 +161,7 @@ TEST_P(GerTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GerTestSuite, GerTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/ger_usm.cpp b/tests/unit_tests/blas/level2/ger_usm.cpp index f8fe34a94..87087f026 100644 --- a/tests/unit_tests/blas/level2/ger_usm.cpp +++ b/tests/unit_tests/blas/level2/ger_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::ger(main_queue, m, n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, dependencies); @@ -105,13 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::ger, m, n, alpha, - x.data(), incx, y.data(), incy, A.data(), lda, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::ger, m, n, + alpha, x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::ger, m, n, alpha, - x.data(), incx, y.data(), incy, A.data(), lda, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::ger, m, n, alpha, + x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; default: break; } @@ -151,6 +153,8 @@ TEST_P(GerUsmTests, RealSinglePrecision) { test(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 1, 1, 42)); } TEST_P(GerUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( test(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 2, 3, 42)); @@ -162,7 +166,7 @@ TEST_P(GerUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GerUsmTestSuite, GerUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/gerc.cpp b/tests/unit_tests/blas/level2/gerc.cpp index 2aa7d02c0..c19c9f029 100644 --- a/tests/unit_tests/blas/level2/gerc.cpp +++ b/tests/unit_tests/blas/level2/gerc.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gerc(main_queue, m, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; @@ -102,13 +102,13 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gerc, m, n, alpha, - x_buffer, incx, y_buffer, incy, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gerc, m, n, + alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gerc, m, n, alpha, - x_buffer, incx, y_buffer, incy, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gerc, m, n, alpha, + x_buffer, incx, y_buffer, incy, A_buffer, lda); break; default: break; } @@ -128,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, m, n, lda, std::max(m, n), std::cout); @@ -148,6 +148,8 @@ TEST_P(GercTests, ComplexSinglePrecision) { 25, 30, alpha, 1, 1, 42)); } TEST_P(GercTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 2, 3, 42)); @@ -159,7 +161,7 @@ TEST_P(GercTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GercTestSuite, GercTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/gerc_usm.cpp b/tests/unit_tests/blas/level2/gerc_usm.cpp index a019d10b1..b6473484d 100644 --- a/tests/unit_tests/blas/level2/gerc_usm.cpp +++ b/tests/unit_tests/blas/level2/gerc_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gerc(main_queue, m, n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, dependencies); @@ -105,13 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gerc, m, n, alpha, - x.data(), incx, y.data(), incy, A.data(), lda, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gerc, m, n, + alpha, x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gerc, m, n, alpha, - x.data(), incx, y.data(), incy, A.data(), lda, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gerc, m, n, alpha, + x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; default: break; } @@ -151,6 +153,8 @@ TEST_P(GercUsmTests, ComplexSinglePrecision) { 25, 30, alpha, 1, 1, 42)); } TEST_P(GercUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 2, 3, 42)); @@ -162,7 +166,7 @@ TEST_P(GercUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GercUsmTestSuite, GercUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/geru.cpp b/tests/unit_tests/blas/level2/geru.cpp index e2e0fd563..e0cb7c45d 100644 --- a/tests/unit_tests/blas/level2/geru.cpp +++ b/tests/unit_tests/blas/level2/geru.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::geru(main_queue, m, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; @@ -102,13 +102,13 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::geru, m, n, alpha, - x_buffer, incx, y_buffer, incy, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::geru, m, n, + alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::geru, m, n, alpha, - x_buffer, incx, y_buffer, incy, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::geru, m, n, alpha, + x_buffer, incx, y_buffer, incy, A_buffer, lda); break; default: break; } @@ -128,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, m, n, lda, std::max(m, n), std::cout); @@ -148,6 +148,8 @@ TEST_P(GeruTests, ComplexSinglePrecision) { 25, 30, alpha, 1, 1, 42)); } TEST_P(GeruTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 2, 3, 42)); @@ -159,7 +161,7 @@ TEST_P(GeruTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GeruTestSuite, GeruTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/geru_usm.cpp b/tests/unit_tests/blas/level2/geru_usm.cpp index 82e5467ad..1e882bd97 100644 --- a/tests/unit_tests/blas/level2/geru_usm.cpp +++ b/tests/unit_tests/blas/level2/geru_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::geru(main_queue, m, n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, dependencies); @@ -105,13 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, int m, int n, fp alpha, int in done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::geru, m, n, alpha, - x.data(), incx, y.data(), incy, A.data(), lda, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::geru, m, n, + alpha, x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::geru, m, n, alpha, - x.data(), incx, y.data(), incy, A.data(), lda, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::geru, m, n, alpha, + x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; default: break; } @@ -151,6 +153,8 @@ TEST_P(GeruUsmTests, ComplexSinglePrecision) { 25, 30, alpha, 1, 1, 42)); } TEST_P(GeruUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), 25, 30, alpha, 2, 3, 42)); @@ -162,7 +166,7 @@ TEST_P(GeruUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GeruUsmTestSuite, GeruUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hbmv.cpp b/tests/unit_tests/blas/level2/hbmv.cpp index 54dc3eaa3..119aef32a 100644 --- a/tests/unit_tests/blas/level2/hbmv.cpp +++ b/tests/unit_tests/blas/level2/hbmv.cpp @@ -91,7 +91,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::hbmv(main_queue, upper_lower, n, k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); @@ -104,14 +104,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hbmv, upper_lower, - n, k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, - incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hbmv, + upper_lower, n, k, alpha, A_buffer, lda, x_buffer, incx, + beta, y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hbmv, upper_lower, n, - k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hbmv, upper_lower, + n, k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, + incy); break; default: break; } @@ -131,7 +132,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, n, incy, n, std::cout); return (int)good; @@ -163,6 +164,8 @@ TEST_P(HbmvTests, ComplexSinglePrecision) { 42)); } TEST_P(HbmvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -187,7 +190,7 @@ TEST_P(HbmvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HbmvTestSuite, HbmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hbmv_usm.cpp b/tests/unit_tests/blas/level2/hbmv_usm.cpp index cdbbd8d7d..60305cb93 100644 --- a/tests/unit_tests/blas/level2/hbmv_usm.cpp +++ b/tests/unit_tests/blas/level2/hbmv_usm.cpp @@ -92,7 +92,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::hbmv(main_queue, upper_lower, n, k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, dependencies); @@ -107,15 +107,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hbmv, upper_lower, - n, k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hbmv, + upper_lower, n, k, alpha, A.data(), lda, x.data(), incx, + beta, y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hbmv, upper_lower, n, - k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hbmv, upper_lower, + n, k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), + incy, dependencies); break; default: break; } @@ -168,6 +168,8 @@ TEST_P(HbmvUsmTests, ComplexSinglePrecision) { 42)); } TEST_P(HbmvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -192,7 +194,7 @@ TEST_P(HbmvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HbmvUsmTestSuite, HbmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hemv.cpp b/tests/unit_tests/blas/level2/hemv.cpp index 364a38aa5..3636e3774 100644 --- a/tests/unit_tests/blas/level2/hemv.cpp +++ b/tests/unit_tests/blas/level2/hemv.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::hemv(main_queue, upper_lower, n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); break; @@ -102,13 +102,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemv, upper_lower, - n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemv, + upper_lower, n, alpha, A_buffer, lda, x_buffer, incx, beta, + y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemv, upper_lower, n, - alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemv, upper_lower, + n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, + incy); break; default: break; } @@ -128,7 +130,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, n, incy, n, std::cout); return (int)good; @@ -160,6 +162,8 @@ TEST_P(HemvTests, ComplexSinglePrecision) { 42)); } TEST_P(HemvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -184,7 +188,7 @@ TEST_P(HemvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HemvTestSuite, HemvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hemv_usm.cpp b/tests/unit_tests/blas/level2/hemv_usm.cpp index b620da877..a1b8093fc 100644 --- a/tests/unit_tests/blas/level2/hemv_usm.cpp +++ b/tests/unit_tests/blas/level2/hemv_usm.cpp @@ -91,7 +91,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::hemv(main_queue, upper_lower, n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, dependencies); @@ -106,15 +106,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemv, upper_lower, - n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemv, + upper_lower, n, alpha, A.data(), lda, x.data(), incx, beta, + y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemv, upper_lower, n, - alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemv, upper_lower, + n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), + incy, dependencies); break; default: break; } @@ -167,6 +167,8 @@ TEST_P(HemvUsmTests, ComplexSinglePrecision) { 42)); } TEST_P(HemvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -191,7 +193,7 @@ TEST_P(HemvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HemvUsmTestSuite, HemvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/her.cpp b/tests/unit_tests/blas/level2/her.cpp index 53a9b7d2d..46ae9a879 100644 --- a/tests/unit_tests/blas/level2/her.cpp +++ b/tests/unit_tests/blas/level2/her.cpp @@ -87,7 +87,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::her(main_queue, upper_lower, n, alpha, x_buffer, incx, A_buffer, lda); break; @@ -99,13 +99,13 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her, upper_lower, n, - alpha, x_buffer, incx, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her, + upper_lower, n, alpha, x_buffer, incx, A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her, upper_lower, n, - alpha, x_buffer, incx, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her, upper_lower, + n, alpha, x_buffer, incx, A_buffer, lda); break; default: break; } @@ -125,7 +125,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, lda, n, std::cout); return (int)good; @@ -156,6 +156,8 @@ TEST_P(HerTests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 42))); } TEST_P(HerTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( (test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -179,7 +181,7 @@ TEST_P(HerTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HerTestSuite, HerTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/her2.cpp b/tests/unit_tests/blas/level2/her2.cpp index c6d868b64..e98c5cc8b 100644 --- a/tests/unit_tests/blas/level2/her2.cpp +++ b/tests/unit_tests/blas/level2/her2.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::her2(main_queue, upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; @@ -102,13 +102,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2, upper_lower, - n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2, + upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, + A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2, upper_lower, n, - alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2, upper_lower, + n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; default: break; } @@ -128,7 +129,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, lda, n, std::cout); return (int)good; @@ -153,6 +154,8 @@ TEST_P(Her2Tests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1, 42)); } TEST_P(Her2Tests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3, 42)); @@ -170,7 +173,7 @@ TEST_P(Her2Tests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Her2TestSuite, Her2Tests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/her2_usm.cpp b/tests/unit_tests/blas/level2/her2_usm.cpp index 70b2a153e..c732331ee 100644 --- a/tests/unit_tests/blas/level2/her2_usm.cpp +++ b/tests/unit_tests/blas/level2/her2_usm.cpp @@ -91,7 +91,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::her2(main_queue, upper_lower, n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, dependencies); @@ -106,15 +106,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2, upper_lower, - n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2, + upper_lower, n, alpha, x.data(), incx, y.data(), incy, + A.data(), lda, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2, upper_lower, n, - alpha, x.data(), incx, y.data(), incy, A.data(), lda, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2, upper_lower, + n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; default: break; } @@ -160,6 +160,8 @@ TEST_P(Her2UsmTests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1, 42)); } TEST_P(Her2UsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3, 42)); @@ -177,7 +179,7 @@ TEST_P(Her2UsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Her2UsmTestSuite, Her2UsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/her_usm.cpp b/tests/unit_tests/blas/level2/her_usm.cpp index 208e153a0..9e1f5099e 100644 --- a/tests/unit_tests/blas/level2/her_usm.cpp +++ b/tests/unit_tests/blas/level2/her_usm.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::her( main_queue, upper_lower, n, alpha, x.data(), incx, A.data(), lda, dependencies); break; @@ -102,13 +102,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her, upper_lower, n, - alpha, x.data(), incx, A.data(), lda, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her, + upper_lower, n, alpha, x.data(), incx, A.data(), lda, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her, upper_lower, n, - alpha, x.data(), incx, A.data(), lda, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her, upper_lower, + n, alpha, x.data(), incx, A.data(), lda, dependencies); break; default: break; } @@ -160,6 +161,8 @@ TEST_P(HerUsmTests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 42))); } TEST_P(HerUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP( (test, double>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -183,7 +186,7 @@ TEST_P(HerUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HerUsmTestSuite, HerUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hpmv.cpp b/tests/unit_tests/blas/level2/hpmv.cpp index c6b3d8485..69e6ea9b2 100644 --- a/tests/unit_tests/blas/level2/hpmv.cpp +++ b/tests/unit_tests/blas/level2/hpmv.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::hpmv(main_queue, upper_lower, n, alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpmv, upper_lower, - n, alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpmv, + upper_lower, n, alpha, A_buffer, x_buffer, incx, beta, + y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpmv, upper_lower, n, - alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpmv, upper_lower, + n, alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); break; default: break; } @@ -127,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, n, incy, n, std::cout); return (int)good; @@ -153,6 +154,8 @@ TEST_P(HpmvTests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, beta, 1, 1)); } TEST_P(HpmvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -173,7 +176,7 @@ TEST_P(HpmvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HpmvTestSuite, HpmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hpmv_usm.cpp b/tests/unit_tests/blas/level2/hpmv_usm.cpp index cf3010f93..743194b18 100644 --- a/tests/unit_tests/blas/level2/hpmv_usm.cpp +++ b/tests/unit_tests/blas/level2/hpmv_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::hpmv(main_queue, upper_lower, n, alpha, A.data(), x.data(), incx, beta, y.data(), incy, dependencies); @@ -105,15 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpmv, upper_lower, - n, alpha, A.data(), x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpmv, + upper_lower, n, alpha, A.data(), x.data(), incx, beta, + y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpmv, upper_lower, n, - alpha, A.data(), x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpmv, upper_lower, + n, alpha, A.data(), x.data(), incx, beta, y.data(), incy, + dependencies); break; default: break; } @@ -160,6 +160,8 @@ TEST_P(HpmvUsmTests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, beta, 1, 1)); } TEST_P(HpmvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -180,7 +182,7 @@ TEST_P(HpmvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HpmvUsmTestSuite, HpmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hpr.cpp b/tests/unit_tests/blas/level2/hpr.cpp index 20288cfdd..b2e5548bd 100644 --- a/tests/unit_tests/blas/level2/hpr.cpp +++ b/tests/unit_tests/blas/level2/hpr.cpp @@ -87,7 +87,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::hpr(main_queue, upper_lower, n, alpha, x_buffer, incx, A_buffer); break; @@ -99,13 +99,13 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr, upper_lower, n, - alpha, x_buffer, incx, A_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr, + upper_lower, n, alpha, x_buffer, incx, A_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr, upper_lower, n, - alpha, x_buffer, incx, A_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr, upper_lower, + n, alpha, x_buffer, incx, A_buffer); break; default: break; } @@ -125,7 +125,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, n, n, std::cout); return (int)good; @@ -153,6 +153,8 @@ TEST_P(HprTests, ComplexSinglePrecision) { } TEST_P(HprTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP((test, double>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2))); @@ -172,7 +174,7 @@ TEST_P(HprTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HprTestSuite, HprTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hpr2.cpp b/tests/unit_tests/blas/level2/hpr2.cpp index 2b4d62835..e2b19e2fd 100644 --- a/tests/unit_tests/blas/level2/hpr2.cpp +++ b/tests/unit_tests/blas/level2/hpr2.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::hpr2(main_queue, upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr2, upper_lower, - n, alpha, x_buffer, incx, y_buffer, incy, A_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr2, + upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, + A_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr2, upper_lower, n, - alpha, x_buffer, incx, y_buffer, incy, A_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr2, upper_lower, + n, alpha, x_buffer, incx, y_buffer, incy, A_buffer); break; default: break; } @@ -127,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, n, n, std::cout); return (int)good; @@ -152,6 +153,8 @@ TEST_P(Hpr2Tests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1)); } TEST_P(Hpr2Tests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3)); @@ -169,7 +172,7 @@ TEST_P(Hpr2Tests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Hpr2TestSuite, Hpr2Tests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hpr2_usm.cpp b/tests/unit_tests/blas/level2/hpr2_usm.cpp index bde134ca5..6dc60dbf6 100644 --- a/tests/unit_tests/blas/level2/hpr2_usm.cpp +++ b/tests/unit_tests/blas/level2/hpr2_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::hpr2(main_queue, upper_lower, n, alpha, x.data(), incx, y.data(), incy, A.data(), dependencies); @@ -105,14 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr2, upper_lower, - n, alpha, x.data(), incx, y.data(), incy, A.data(), - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr2, + upper_lower, n, alpha, x.data(), incx, y.data(), incy, + A.data(), dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr2, upper_lower, n, - alpha, x.data(), incx, y.data(), incy, A.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr2, upper_lower, + n, alpha, x.data(), incx, y.data(), incy, A.data(), + dependencies); break; default: break; } @@ -158,6 +159,8 @@ TEST_P(Hpr2UsmTests, ComplexSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1)); } TEST_P(Hpr2UsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3)); @@ -175,7 +178,7 @@ TEST_P(Hpr2UsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Hpr2UsmTestSuite, Hpr2UsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/hpr_usm.cpp b/tests/unit_tests/blas/level2/hpr_usm.cpp index 2101f2023..b90b0ee63 100644 --- a/tests/unit_tests/blas/level2/hpr_usm.cpp +++ b/tests/unit_tests/blas/level2/hpr_usm.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::hpr(main_queue, upper_lower, n, alpha, x.data(), incx, A.data(), dependencies); break; @@ -102,13 +102,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr, upper_lower, n, - alpha, x.data(), incx, A.data(), dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hpr, + upper_lower, n, alpha, x.data(), incx, A.data(), + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr, upper_lower, n, - alpha, x.data(), incx, A.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hpr, upper_lower, + n, alpha, x.data(), incx, A.data(), dependencies); break; default: break; } @@ -157,6 +158,8 @@ TEST_P(HprUsmTests, ComplexSinglePrecision) { } TEST_P(HprUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP((test, double>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2))); @@ -176,7 +179,7 @@ TEST_P(HprUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HprUsmTestSuite, HprUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/sbmv.cpp b/tests/unit_tests/blas/level2/sbmv.cpp index 6255985f4..c0347dfda 100644 --- a/tests/unit_tests/blas/level2/sbmv.cpp +++ b/tests/unit_tests/blas/level2/sbmv.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::sbmv(main_queue, upper_lower, n, k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); @@ -102,14 +102,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sbmv, upper_lower, - n, k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, - incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sbmv, + upper_lower, n, k, alpha, A_buffer, lda, x_buffer, incx, + beta, y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sbmv, upper_lower, n, - k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sbmv, upper_lower, + n, k, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, + incy); break; default: break; } @@ -129,7 +130,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, n, incy, n, std::cout); return (int)good; @@ -155,6 +156,8 @@ TEST_P(SbmvTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, 5, alpha, beta, 1, 1, 42)); } TEST_P(SbmvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -173,7 +176,7 @@ TEST_P(SbmvTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SbmvTestSuite, SbmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/sbmv_usm.cpp b/tests/unit_tests/blas/level2/sbmv_usm.cpp index abf9dfa94..4fb7d46ad 100644 --- a/tests/unit_tests/blas/level2/sbmv_usm.cpp +++ b/tests/unit_tests/blas/level2/sbmv_usm.cpp @@ -91,7 +91,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::sbmv(main_queue, upper_lower, n, k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, dependencies); @@ -106,15 +106,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sbmv, upper_lower, - n, k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::sbmv, + upper_lower, n, k, alpha, A.data(), lda, x.data(), incx, + beta, y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sbmv, upper_lower, n, - k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::sbmv, upper_lower, + n, k, alpha, A.data(), lda, x.data(), incx, beta, y.data(), + incy, dependencies); break; default: break; } @@ -161,6 +161,8 @@ TEST_P(SbmvUsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, 5, alpha, beta, 1, 1, 42)); } TEST_P(SbmvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -179,7 +181,7 @@ TEST_P(SbmvUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SbmvUsmTestSuite, SbmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/spmv.cpp b/tests/unit_tests/blas/level2/spmv.cpp index a445ef5c7..799e7d775 100644 --- a/tests/unit_tests/blas/level2/spmv.cpp +++ b/tests/unit_tests/blas/level2/spmv.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::spmv(main_queue, upper_lower, n, alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spmv, upper_lower, - n, alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spmv, + upper_lower, n, alpha, A_buffer, x_buffer, incx, beta, + y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spmv, upper_lower, n, - alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spmv, upper_lower, + n, alpha, A_buffer, x_buffer, incx, beta, y_buffer, incy); break; default: break; } @@ -127,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, n, incy, n, std::cout); return (int)good; @@ -153,6 +154,8 @@ TEST_P(SpmvTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, beta, 1, 1)); } TEST_P(SpmvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -171,7 +174,7 @@ TEST_P(SpmvTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SpmvTestSuite, SpmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/spmv_usm.cpp b/tests/unit_tests/blas/level2/spmv_usm.cpp index 67a778fe2..ae38ada4a 100644 --- a/tests/unit_tests/blas/level2/spmv_usm.cpp +++ b/tests/unit_tests/blas/level2/spmv_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::spmv(main_queue, upper_lower, n, alpha, A.data(), x.data(), incx, beta, y.data(), incy, dependencies); @@ -105,15 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spmv, upper_lower, - n, alpha, A.data(), x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spmv, + upper_lower, n, alpha, A.data(), x.data(), incx, beta, + y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spmv, upper_lower, n, - alpha, A.data(), x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spmv, upper_lower, + n, alpha, A.data(), x.data(), incx, beta, y.data(), incy, + dependencies); break; default: break; } @@ -160,6 +160,8 @@ TEST_P(SpmvUsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, beta, 1, 1)); } TEST_P(SpmvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -178,7 +180,7 @@ TEST_P(SpmvUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SpmvUsmTestSuite, SpmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/spr.cpp b/tests/unit_tests/blas/level2/spr.cpp index 0823a3668..4e4b5d8a9 100644 --- a/tests/unit_tests/blas/level2/spr.cpp +++ b/tests/unit_tests/blas/level2/spr.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::spr(main_queue, upper_lower, n, alpha, x_buffer, incx, A_buffer); break; @@ -98,13 +98,13 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr, upper_lower, n, - alpha, x_buffer, incx, A_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr, + upper_lower, n, alpha, x_buffer, incx, A_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr, upper_lower, n, - alpha, x_buffer, incx, A_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr, upper_lower, + n, alpha, x_buffer, incx, A_buffer); break; default: break; } @@ -124,7 +124,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, n, n, std::cout); return (int)good; @@ -149,6 +149,8 @@ TEST_P(SprTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1)); } TEST_P(SprTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2)); @@ -166,7 +168,7 @@ TEST_P(SprTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SprTestSuite, SprTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/spr2.cpp b/tests/unit_tests/blas/level2/spr2.cpp index 82620a325..d9d00a4e8 100644 --- a/tests/unit_tests/blas/level2/spr2.cpp +++ b/tests/unit_tests/blas/level2/spr2.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::spr2(main_queue, upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr2, upper_lower, - n, alpha, x_buffer, incx, y_buffer, incy, A_buffer); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr2, + upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, + A_buffer); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr2, upper_lower, n, - alpha, x_buffer, incx, y_buffer, incy, A_buffer); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr2, upper_lower, + n, alpha, x_buffer, incx, y_buffer, incy, A_buffer); break; default: break; } @@ -127,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, n, n, std::cout); return (int)good; @@ -152,6 +153,8 @@ TEST_P(Spr2Tests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1)); } TEST_P(Spr2Tests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3)); @@ -169,7 +172,7 @@ TEST_P(Spr2Tests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Spr2TestSuite, Spr2Tests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/spr2_usm.cpp b/tests/unit_tests/blas/level2/spr2_usm.cpp index e8c11a281..683288775 100644 --- a/tests/unit_tests/blas/level2/spr2_usm.cpp +++ b/tests/unit_tests/blas/level2/spr2_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::spr2(main_queue, upper_lower, n, alpha, x.data(), incx, y.data(), incy, A.data(), dependencies); @@ -105,14 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr2, upper_lower, - n, alpha, x.data(), incx, y.data(), incy, A.data(), - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr2, + upper_lower, n, alpha, x.data(), incx, y.data(), incy, + A.data(), dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr2, upper_lower, n, - alpha, x.data(), incx, y.data(), incy, A.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr2, upper_lower, + n, alpha, x.data(), incx, y.data(), incy, A.data(), + dependencies); break; default: break; } @@ -158,6 +159,8 @@ TEST_P(Spr2UsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1)); } TEST_P(Spr2UsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3)); @@ -175,7 +178,7 @@ TEST_P(Spr2UsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Spr2UsmTestSuite, Spr2UsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/spr_usm.cpp b/tests/unit_tests/blas/level2/spr_usm.cpp index e053493b0..3a23a33b4 100644 --- a/tests/unit_tests/blas/level2/spr_usm.cpp +++ b/tests/unit_tests/blas/level2/spr_usm.cpp @@ -88,7 +88,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::spr(main_queue, upper_lower, n, alpha, x.data(), incx, A.data(), dependencies); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr, upper_lower, n, - alpha, x.data(), incx, A.data(), dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::spr, + upper_lower, n, alpha, x.data(), incx, A.data(), + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr, upper_lower, n, - alpha, x.data(), incx, A.data(), dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::spr, upper_lower, + n, alpha, x.data(), incx, A.data(), dependencies); break; default: break; } @@ -153,6 +154,8 @@ TEST_P(SprUsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1)); } TEST_P(SprUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2)); @@ -170,7 +173,7 @@ TEST_P(SprUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SprUsmTestSuite, SprUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/symv.cpp b/tests/unit_tests/blas/level2/symv.cpp index 183123d0f..a22e48ff7 100644 --- a/tests/unit_tests/blas/level2/symv.cpp +++ b/tests/unit_tests/blas/level2/symv.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::symv(main_queue, upper_lower, n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); break; @@ -101,13 +101,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symv, upper_lower, - n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symv, + upper_lower, n, alpha, A_buffer, lda, x_buffer, incx, beta, + y_buffer, incy); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symv, upper_lower, n, - alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, incy); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symv, upper_lower, + n, alpha, A_buffer, lda, x_buffer, incx, beta, y_buffer, + incy); break; default: break; } @@ -127,7 +129,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto y_accessor = y_buffer.template get_host_access(read_only); + auto y_accessor = y_buffer.get_host_access(read_only); bool good = check_equal_vector(y_accessor, y_ref, n, incy, n, std::cout); return (int)good; @@ -153,6 +155,8 @@ TEST_P(SymvTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, beta, 1, 1, 42)); } TEST_P(SymvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -171,7 +175,7 @@ TEST_P(SymvTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SymvTestSuite, SymvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/symv_usm.cpp b/tests/unit_tests/blas/level2/symv_usm.cpp index b9b0d3595..f33c0d25f 100644 --- a/tests/unit_tests/blas/level2/symv_usm.cpp +++ b/tests/unit_tests/blas/level2/symv_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::symv(main_queue, upper_lower, n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, dependencies); @@ -105,15 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symv, upper_lower, - n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symv, + upper_lower, n, alpha, A.data(), lda, x.data(), incx, beta, + y.data(), incy, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symv, upper_lower, n, - alpha, A.data(), lda, x.data(), incx, beta, y.data(), incy, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symv, upper_lower, + n, alpha, A.data(), lda, x.data(), incx, beta, y.data(), + incy, dependencies); break; default: break; } @@ -160,6 +160,8 @@ TEST_P(SymvUsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, beta, 1, 1, 42)); } TEST_P(SymvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -178,7 +180,7 @@ TEST_P(SymvUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SymvUsmTestSuite, SymvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/syr.cpp b/tests/unit_tests/blas/level2/syr.cpp index e8f47b546..6b305582b 100644 --- a/tests/unit_tests/blas/level2/syr.cpp +++ b/tests/unit_tests/blas/level2/syr.cpp @@ -86,7 +86,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::syr(main_queue, upper_lower, n, alpha, x_buffer, incx, A_buffer, lda); break; @@ -98,13 +98,13 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr, upper_lower, n, - alpha, x_buffer, incx, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr, + upper_lower, n, alpha, x_buffer, incx, A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr, upper_lower, n, - alpha, x_buffer, incx, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr, upper_lower, + n, alpha, x_buffer, incx, A_buffer, lda); break; default: break; } @@ -124,7 +124,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, lda, n, std::cout); return (int)good; @@ -149,6 +149,8 @@ TEST_P(SyrTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 42)); } TEST_P(SyrTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 42)); @@ -166,7 +168,7 @@ TEST_P(SyrTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SyrTestSuite, SyrTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/syr2.cpp b/tests/unit_tests/blas/level2/syr2.cpp index 1b3db23b5..5da1e0106 100644 --- a/tests/unit_tests/blas/level2/syr2.cpp +++ b/tests/unit_tests/blas/level2/syr2.cpp @@ -89,7 +89,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::syr2(main_queue, upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2, upper_lower, - n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2, + upper_lower, n, alpha, x_buffer, incx, y_buffer, incy, + A_buffer, lda); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2, upper_lower, n, - alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2, upper_lower, + n, alpha, x_buffer, incx, y_buffer, incy, A_buffer, lda); break; default: break; } @@ -127,7 +128,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto A_accessor = A_buffer.template get_host_access(read_only); + auto A_accessor = A_buffer.get_host_access(read_only); bool good = check_equal_matrix(A_accessor, A_ref, layout, n, n, lda, n, std::cout); return (int)good; @@ -152,6 +153,8 @@ TEST_P(Syr2Tests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1, 42)); } TEST_P(Syr2Tests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3, 42)); @@ -169,7 +172,7 @@ TEST_P(Syr2Tests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Syr2TestSuite, Syr2Tests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/syr2_usm.cpp b/tests/unit_tests/blas/level2/syr2_usm.cpp index 8ca78d9b1..a1e2cba7d 100644 --- a/tests/unit_tests/blas/level2/syr2_usm.cpp +++ b/tests/unit_tests/blas/level2/syr2_usm.cpp @@ -90,7 +90,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::syr2(main_queue, upper_lower, n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, dependencies); @@ -105,15 +105,15 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2, upper_lower, - n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2, + upper_lower, n, alpha, x.data(), incx, y.data(), incy, + A.data(), lda, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2, upper_lower, n, - alpha, x.data(), incx, y.data(), incy, A.data(), lda, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2, upper_lower, + n, alpha, x.data(), incx, y.data(), incy, A.data(), lda, + dependencies); break; default: break; } @@ -159,6 +159,8 @@ TEST_P(Syr2UsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 1, 42)); } TEST_P(Syr2UsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 3, 42)); @@ -176,7 +178,7 @@ TEST_P(Syr2UsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Syr2UsmTestSuite, Syr2UsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/syr_usm.cpp b/tests/unit_tests/blas/level2/syr_usm.cpp index cbc1e1d89..5a9f5034d 100644 --- a/tests/unit_tests/blas/level2/syr_usm.cpp +++ b/tests/unit_tests/blas/level2/syr_usm.cpp @@ -88,7 +88,7 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::syr( main_queue, upper_lower, n, alpha, x.data(), incx, A.data(), lda, dependencies); break; @@ -101,13 +101,14 @@ int test(device *dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr, upper_lower, n, - alpha, x.data(), incx, A.data(), lda, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr, + upper_lower, n, alpha, x.data(), incx, A.data(), lda, + dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr, upper_lower, n, - alpha, x.data(), incx, A.data(), lda, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr, upper_lower, + n, alpha, x.data(), incx, A.data(), lda, dependencies); break; default: break; } @@ -153,6 +154,8 @@ TEST_P(SyrUsmTests, RealSinglePrecision) { oneapi::mkl::uplo::upper, 30, alpha, 1, 42)); } TEST_P(SyrUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, 30, alpha, 2, 42)); @@ -170,7 +173,7 @@ TEST_P(SyrUsmTests, RealDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SyrUsmTestSuite, SyrUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tbmv.cpp b/tests/unit_tests/blas/level2/tbmv.cpp index 4087e92be..554082a01 100644 --- a/tests/unit_tests/blas/level2/tbmv.cpp +++ b/tests/unit_tests/blas/level2/tbmv.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::tbmv(main_queue, upper_lower, transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); break; @@ -101,13 +101,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbmv, upper_lower, - transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbmv, + upper_lower, transa, unit_nonunit, n, k, A_buffer, lda, + x_buffer, incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbmv, upper_lower, - transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbmv, upper_lower, + transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); break; default: break; } @@ -127,7 +128,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_vector(x_accessor, x_ref, n, incx, n, std::cout); return (int)good; @@ -163,6 +164,8 @@ TEST_P(TbmvTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbmvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -227,6 +230,8 @@ TEST_P(TbmvTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbmvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -267,7 +272,7 @@ TEST_P(TbmvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TbmvTestSuite, TbmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tbmv_usm.cpp b/tests/unit_tests/blas/level2/tbmv_usm.cpp index 2811c7228..808c5d1c3 100644 --- a/tests/unit_tests/blas/level2/tbmv_usm.cpp +++ b/tests/unit_tests/blas/level2/tbmv_usm.cpp @@ -91,7 +91,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::tbmv(main_queue, upper_lower, transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, dependencies); @@ -106,15 +106,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbmv, upper_lower, - transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbmv, + upper_lower, transa, unit_nonunit, n, k, A.data(), lda, + x.data(), incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbmv, upper_lower, - transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbmv, upper_lower, + transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, + dependencies); break; default: break; } @@ -171,6 +171,8 @@ TEST_P(TbmvUsmTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbmvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -235,6 +237,8 @@ TEST_P(TbmvUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbmvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -275,7 +279,7 @@ TEST_P(TbmvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TbmvUsmTestSuite, TbmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tbsv.cpp b/tests/unit_tests/blas/level2/tbsv.cpp index 513cf2af8..e653105e8 100644 --- a/tests/unit_tests/blas/level2/tbsv.cpp +++ b/tests/unit_tests/blas/level2/tbsv.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::tbsv(main_queue, upper_lower, transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); break; @@ -101,13 +101,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbsv, upper_lower, - transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbsv, + upper_lower, transa, unit_nonunit, n, k, A_buffer, lda, + x_buffer, incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbsv, upper_lower, - transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbsv, upper_lower, + transa, unit_nonunit, n, k, A_buffer, lda, x_buffer, incx); break; default: break; } @@ -127,7 +128,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_trsv_vector(x_accessor, x_ref, n, incx, n, std::cout); return (int)good; @@ -163,6 +164,8 @@ TEST_P(TbsvTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbsvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -227,6 +230,8 @@ TEST_P(TbsvTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbsvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -267,7 +272,7 @@ TEST_P(TbsvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TbsvTestSuite, TbsvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tbsv_usm.cpp b/tests/unit_tests/blas/level2/tbsv_usm.cpp index e6afe6146..1b77997eb 100644 --- a/tests/unit_tests/blas/level2/tbsv_usm.cpp +++ b/tests/unit_tests/blas/level2/tbsv_usm.cpp @@ -91,7 +91,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::tbsv(main_queue, upper_lower, transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, dependencies); @@ -106,15 +106,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbsv, upper_lower, - transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tbsv, + upper_lower, transa, unit_nonunit, n, k, A.data(), lda, + x.data(), incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbsv, upper_lower, - transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tbsv, upper_lower, + transa, unit_nonunit, n, k, A.data(), lda, x.data(), incx, + dependencies); break; default: break; } @@ -171,6 +171,8 @@ TEST_P(TbsvUsmTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbsvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -235,6 +237,8 @@ TEST_P(TbsvUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 5, 2, 42)); } TEST_P(TbsvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 5, 2, 42)); @@ -275,7 +279,7 @@ TEST_P(TbsvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TbsvUsmTestSuite, TbsvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tpmv.cpp b/tests/unit_tests/blas/level2/tpmv.cpp index a481edac1..ce45279bb 100644 --- a/tests/unit_tests/blas/level2/tpmv.cpp +++ b/tests/unit_tests/blas/level2/tpmv.cpp @@ -87,7 +87,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::tpmv(main_queue, upper_lower, transa, unit_nonunit, n, A_buffer, x_buffer, incx); break; @@ -99,13 +99,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpmv, upper_lower, - transa, unit_nonunit, n, A_buffer, x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpmv, + upper_lower, transa, unit_nonunit, n, A_buffer, x_buffer, + incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpmv, upper_lower, - transa, unit_nonunit, n, A_buffer, x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpmv, upper_lower, + transa, unit_nonunit, n, A_buffer, x_buffer, incx); break; default: break; } @@ -125,7 +126,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_vector(x_accessor, x_ref, n, incx, n, std::cout); return (int)good; @@ -161,6 +162,8 @@ TEST_P(TpmvTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpmvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -225,6 +228,8 @@ TEST_P(TpmvTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpmvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -265,7 +270,7 @@ TEST_P(TpmvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TpmvTestSuite, TpmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tpmv_usm.cpp b/tests/unit_tests/blas/level2/tpmv_usm.cpp index ff17b4504..74ebc2502 100644 --- a/tests/unit_tests/blas/level2/tpmv_usm.cpp +++ b/tests/unit_tests/blas/level2/tpmv_usm.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::tpmv(main_queue, upper_lower, transa, unit_nonunit, n, A.data(), x.data(), incx, dependencies); @@ -104,13 +104,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpmv, upper_lower, - transa, unit_nonunit, n, A.data(), x.data(), incx, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpmv, + upper_lower, transa, unit_nonunit, n, A.data(), x.data(), + incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpmv, upper_lower, - transa, unit_nonunit, n, A.data(), x.data(), incx, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpmv, upper_lower, + transa, unit_nonunit, n, A.data(), x.data(), incx, + dependencies); break; default: break; } @@ -167,6 +169,8 @@ TEST_P(TpmvUsmTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpmvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -231,6 +235,8 @@ TEST_P(TpmvUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpmvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -271,7 +277,7 @@ TEST_P(TpmvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TpmvUsmTestSuite, TpmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tpsv.cpp b/tests/unit_tests/blas/level2/tpsv.cpp index 2b825f92e..2a12ab1da 100644 --- a/tests/unit_tests/blas/level2/tpsv.cpp +++ b/tests/unit_tests/blas/level2/tpsv.cpp @@ -87,7 +87,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::tpsv(main_queue, upper_lower, transa, unit_nonunit, n, A_buffer, x_buffer, incx); break; @@ -99,13 +99,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpsv, upper_lower, - transa, unit_nonunit, n, A_buffer, x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpsv, + upper_lower, transa, unit_nonunit, n, A_buffer, x_buffer, + incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpsv, upper_lower, - transa, unit_nonunit, n, A_buffer, x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpsv, upper_lower, + transa, unit_nonunit, n, A_buffer, x_buffer, incx); break; default: break; } @@ -125,7 +126,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_trsv_vector(x_accessor, x_ref, n, incx, n, std::cout); return (int)good; @@ -161,6 +162,8 @@ TEST_P(TpsvTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpsvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -225,6 +228,8 @@ TEST_P(TpsvTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpsvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -265,7 +270,7 @@ TEST_P(TpsvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TpsvTestSuite, TpsvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/tpsv_usm.cpp b/tests/unit_tests/blas/level2/tpsv_usm.cpp index b9a22cbdc..bcb676843 100644 --- a/tests/unit_tests/blas/level2/tpsv_usm.cpp +++ b/tests/unit_tests/blas/level2/tpsv_usm.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::tpsv(main_queue, upper_lower, transa, unit_nonunit, n, A.data(), x.data(), incx, dependencies); @@ -104,13 +104,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpsv, upper_lower, - transa, unit_nonunit, n, A.data(), x.data(), incx, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::tpsv, + upper_lower, transa, unit_nonunit, n, A.data(), x.data(), + incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpsv, upper_lower, - transa, unit_nonunit, n, A.data(), x.data(), incx, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::tpsv, upper_lower, + transa, unit_nonunit, n, A.data(), x.data(), incx, + dependencies); break; default: break; } @@ -167,6 +169,8 @@ TEST_P(TpsvUsmTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpsvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -231,6 +235,8 @@ TEST_P(TpsvUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2)); } TEST_P(TpsvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2)); @@ -271,7 +277,7 @@ TEST_P(TpsvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TpsvUsmTestSuite, TpsvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/trmv.cpp b/tests/unit_tests/blas/level2/trmv.cpp index f91c4b4af..8dfc517eb 100644 --- a/tests/unit_tests/blas/level2/trmv.cpp +++ b/tests/unit_tests/blas/level2/trmv.cpp @@ -87,7 +87,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::trmv(main_queue, upper_lower, transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); break; @@ -99,13 +99,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmv, upper_lower, - transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmv, + upper_lower, transa, unit_nonunit, n, A_buffer, lda, + x_buffer, incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmv, upper_lower, - transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmv, upper_lower, + transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); break; default: break; } @@ -125,7 +126,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_vector(x_accessor, x_ref, n, incx, n, std::cout); return (int)good; @@ -161,6 +162,8 @@ TEST_P(TrmvTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrmvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -225,6 +228,8 @@ TEST_P(TrmvTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrmvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -265,7 +270,7 @@ TEST_P(TrmvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrmvTestSuite, TrmvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/trmv_usm.cpp b/tests/unit_tests/blas/level2/trmv_usm.cpp index f9ad1287a..af3e4b898 100644 --- a/tests/unit_tests/blas/level2/trmv_usm.cpp +++ b/tests/unit_tests/blas/level2/trmv_usm.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::trmv(main_queue, upper_lower, transa, unit_nonunit, n, A.data(), lda, x.data(), incx, dependencies); @@ -104,15 +104,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmv, upper_lower, - transa, unit_nonunit, n, A.data(), lda, x.data(), incx, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmv, + upper_lower, transa, unit_nonunit, n, A.data(), lda, + x.data(), incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmv, upper_lower, - transa, unit_nonunit, n, A.data(), lda, x.data(), incx, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmv, upper_lower, + transa, unit_nonunit, n, A.data(), lda, x.data(), incx, + dependencies); break; default: break; } @@ -169,6 +169,8 @@ TEST_P(TrmvUsmTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrmvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -233,6 +235,8 @@ TEST_P(TrmvUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrmvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -273,7 +277,7 @@ TEST_P(TrmvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrmvUsmTestSuite, TrmvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/trsv.cpp b/tests/unit_tests/blas/level2/trsv.cpp index 5db9dbcd1..fb1e39e06 100644 --- a/tests/unit_tests/blas/level2/trsv.cpp +++ b/tests/unit_tests/blas/level2/trsv.cpp @@ -87,7 +87,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::trsv(main_queue, upper_lower, transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); break; @@ -99,13 +99,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsv, upper_lower, - transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsv, + upper_lower, transa, unit_nonunit, n, A_buffer, lda, + x_buffer, incx); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsv, upper_lower, - transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsv, upper_lower, + transa, unit_nonunit, n, A_buffer, lda, x_buffer, incx); break; default: break; } @@ -125,7 +126,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto x_accessor = x_buffer.template get_host_access(read_only); + auto x_accessor = x_buffer.get_host_access(read_only); bool good = check_equal_trsv_vector(x_accessor, x_ref, n, incx, n, std::cout); return (int)good; @@ -161,6 +162,8 @@ TEST_P(TrsvTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrsvTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -225,6 +228,8 @@ TEST_P(TrsvTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrsvTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -265,7 +270,7 @@ TEST_P(TrsvTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrsvTestSuite, TrsvTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level2/trsv_usm.cpp b/tests/unit_tests/blas/level2/trsv_usm.cpp index af4eab466..2e6242d58 100644 --- a/tests/unit_tests/blas/level2/trsv_usm.cpp +++ b/tests/unit_tests/blas/level2/trsv_usm.cpp @@ -89,7 +89,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::trsv(main_queue, upper_lower, transa, unit_nonunit, n, A.data(), lda, x.data(), incx, dependencies); @@ -104,15 +104,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsv, upper_lower, - transa, unit_nonunit, n, A.data(), lda, x.data(), incx, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsv, + upper_lower, transa, unit_nonunit, n, A.data(), lda, + x.data(), incx, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsv, upper_lower, - transa, unit_nonunit, n, A.data(), lda, x.data(), incx, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsv, upper_lower, + transa, unit_nonunit, n, A.data(), lda, x.data(), incx, + dependencies); break; default: break; } @@ -169,6 +169,8 @@ TEST_P(TrsvUsmTests, RealSinglePrecision) { oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrsvUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -233,6 +235,8 @@ TEST_P(TrsvUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, oneapi::mkl::diag::nonunit, 30, 2, 42)); } TEST_P(TrsvUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + EXPECT_TRUEORSKIP(test>( std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::uplo::lower, oneapi::mkl::transpose::nontrans, oneapi::mkl::diag::unit, 30, 2, 42)); @@ -273,7 +277,7 @@ TEST_P(TrsvUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrsvUsmTestSuite, TrsvUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/gemm.cpp b/tests/unit_tests/blas/level3/gemm.cpp index 5c648264d..564700b16 100644 --- a/tests/unit_tests/blas/level3/gemm.cpp +++ b/tests/unit_tests/blas/level3/gemm.cpp @@ -97,7 +97,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::gemm(main_queue, transa, transb, m, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, ldc); @@ -111,15 +111,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm, transa, - transb, m, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm, transa, + transb, m, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, + C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm, transa, transb, - m, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, - ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm, transa, + transb, m, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, + C_buffer, ldc); break; default: break; } @@ -139,7 +139,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, m, n, ldc, 10 * k, std::cout); return (int)good; @@ -220,6 +220,8 @@ TEST_P(GemmTests, RealSinglePrecision) { } TEST_P(GemmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP((test( @@ -269,6 +271,8 @@ TEST_P(GemmTests, ComplexSinglePrecision) { } TEST_P(GemmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP((test, std::complex>( @@ -302,7 +306,7 @@ TEST_P(GemmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemmTestSuite, GemmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/gemm_usm.cpp b/tests/unit_tests/blas/level3/gemm_usm.cpp index 6566de924..9d5d8d048 100644 --- a/tests/unit_tests/blas/level3/gemm_usm.cpp +++ b/tests/unit_tests/blas/level3/gemm_usm.cpp @@ -97,7 +97,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::gemm(main_queue, transa, transb, m, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, dependencies); @@ -112,15 +112,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::transpose transa, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm, transa, - transb, m, n, k, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::gemm, transa, + transb, m, n, k, alpha, A.data(), lda, B.data(), ldb, beta, + C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm, transa, transb, - m, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), - ldc, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::gemm, transa, + transb, m, n, k, alpha, A.data(), lda, B.data(), ldb, beta, + C.data(), ldc, dependencies); break; default: break; } @@ -219,6 +219,8 @@ TEST_P(GemmUsmTests, RealSinglePrecision) { } TEST_P(GemmUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP((test( @@ -268,6 +270,8 @@ TEST_P(GemmUsmTests, ComplexSinglePrecision) { } TEST_P(GemmUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP((test, std::complex>( @@ -301,7 +305,7 @@ TEST_P(GemmUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(GemmUsmTestSuite, GemmUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/hemm.cpp b/tests/unit_tests/blas/level3/hemm.cpp index c74933e01..ce050e97d 100644 --- a/tests/unit_tests/blas/level3/hemm.cpp +++ b/tests/unit_tests/blas/level3/hemm.cpp @@ -96,7 +96,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::hemm(main_queue, left_right, upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, ldc); @@ -110,15 +110,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemm, left_right, - upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemm, + left_right, upper_lower, m, n, alpha, A_buffer, lda, + B_buffer, ldb, beta, C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemm, left_right, - upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemm, left_right, + upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, + beta, C_buffer, ldc); break; default: break; } @@ -138,7 +138,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, m, n, ldc, 10 * std::max(m, n), std::cout); @@ -165,6 +165,8 @@ TEST_P(HemmTests, ComplexSinglePrecision) { 72, 27, 101, 102, 103, alpha, beta)); } TEST_P(HemmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -183,7 +185,7 @@ TEST_P(HemmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HemmTestSuite, HemmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/hemm_usm.cpp b/tests/unit_tests/blas/level3/hemm_usm.cpp index 79cbd175b..eafb06ea5 100644 --- a/tests/unit_tests/blas/level3/hemm_usm.cpp +++ b/tests/unit_tests/blas/level3/hemm_usm.cpp @@ -95,7 +95,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::hemm(main_queue, left_right, upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, dependencies); @@ -110,15 +110,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemm, left_right, - upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::hemm, + left_right, upper_lower, m, n, alpha, A.data(), lda, + B.data(), ldb, beta, C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemm, left_right, - upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::hemm, left_right, + upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, + beta, C.data(), ldc, dependencies); break; default: break; } @@ -165,6 +165,8 @@ TEST_P(HemmUsmTests, ComplexSinglePrecision) { 72, 27, 101, 102, 103, alpha, beta)); } TEST_P(HemmUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -183,7 +185,7 @@ TEST_P(HemmUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HemmUsmTestSuite, HemmUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/her2k.cpp b/tests/unit_tests/blas/level3/her2k.cpp index 405680118..ce57041d9 100644 --- a/tests/unit_tests/blas/level3/her2k.cpp +++ b/tests/unit_tests/blas/level3/her2k.cpp @@ -97,7 +97,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::her2k(main_queue, upper_lower, trans, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, ldc); @@ -111,15 +111,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2k, upper_lower, - trans, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, - ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2k, + upper_lower, trans, n, k, alpha, A_buffer, lda, B_buffer, + ldb, beta, C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2k, upper_lower, - trans, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, - ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2k, + upper_lower, trans, n, k, alpha, A_buffer, lda, B_buffer, + ldb, beta, C_buffer, ldc); break; default: break; } @@ -139,7 +139,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, n, n, ldc, 10 * std::max(n, k), std::cout); @@ -166,6 +166,8 @@ TEST_P(Her2kTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, 72, 27, 101, 102, 103, alpha, beta))); } TEST_P(Her2kTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); double beta(1.0); EXPECT_TRUEORSKIP((test, double>( @@ -184,7 +186,7 @@ TEST_P(Her2kTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Her2kTestSuite, Her2kTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/her2k_usm.cpp b/tests/unit_tests/blas/level3/her2k_usm.cpp index bca617c60..a4ada6cb2 100644 --- a/tests/unit_tests/blas/level3/her2k_usm.cpp +++ b/tests/unit_tests/blas/level3/her2k_usm.cpp @@ -97,7 +97,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::her2k(main_queue, upper_lower, trans, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, dependencies); @@ -112,15 +112,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2k, upper_lower, - trans, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), - ldc, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::her2k, + upper_lower, trans, n, k, alpha, A.data(), lda, B.data(), + ldb, beta, C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2k, upper_lower, - trans, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), - ldc, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::her2k, + upper_lower, trans, n, k, alpha, A.data(), lda, B.data(), + ldb, beta, C.data(), ldc, dependencies); break; default: break; } @@ -167,6 +167,8 @@ TEST_P(Her2kUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, 72, 27, 101, 102, 103, alpha, beta))); } TEST_P(Her2kUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); double beta(1.0); EXPECT_TRUEORSKIP((test, double>( @@ -185,7 +187,7 @@ TEST_P(Her2kUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Her2kUsmTestSuite, Her2kUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/herk.cpp b/tests/unit_tests/blas/level3/herk.cpp index c102acc8f..f908a77b7 100644 --- a/tests/unit_tests/blas/level3/herk.cpp +++ b/tests/unit_tests/blas/level3/herk.cpp @@ -91,7 +91,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::herk(main_queue, upper_lower, trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); break; @@ -103,13 +103,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::herk, upper_lower, - trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::herk, + upper_lower, trans, n, k, alpha, A_buffer, lda, beta, + C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::herk, upper_lower, - trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::herk, upper_lower, + trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); break; default: break; } @@ -129,7 +130,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, n, n, ldc, 10 * std::max(n, k), std::cout); @@ -156,6 +157,8 @@ TEST_P(HerkTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, 72, 27, 101, 103, alpha, beta))); } TEST_P(HerkTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP((test, double>( @@ -174,7 +177,7 @@ TEST_P(HerkTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HerkTestSuite, HerkTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/herk_usm.cpp b/tests/unit_tests/blas/level3/herk_usm.cpp index f547b78a9..470159c63 100644 --- a/tests/unit_tests/blas/level3/herk_usm.cpp +++ b/tests/unit_tests/blas/level3/herk_usm.cpp @@ -92,7 +92,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::herk(main_queue, upper_lower, trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, dependencies); @@ -107,15 +107,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::herk, upper_lower, - trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::herk, + upper_lower, trans, n, k, alpha, A.data(), lda, beta, + C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::herk, upper_lower, - trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::herk, upper_lower, + trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, + dependencies); break; default: break; } @@ -162,6 +162,8 @@ TEST_P(HerkUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::conjtrans, 72, 27, 101, 103, alpha, beta))); } TEST_P(HerkUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP((test, double>( @@ -180,7 +182,7 @@ TEST_P(HerkUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(HerkUsmTestSuite, HerkUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/symm.cpp b/tests/unit_tests/blas/level3/symm.cpp index c57e9ae41..3f6920370 100644 --- a/tests/unit_tests/blas/level3/symm.cpp +++ b/tests/unit_tests/blas/level3/symm.cpp @@ -96,7 +96,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::symm(main_queue, left_right, upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, ldc); @@ -110,15 +110,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symm, left_right, - upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symm, + left_right, upper_lower, m, n, alpha, A_buffer, lda, + B_buffer, ldb, beta, C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symm, left_right, - upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, beta, - C_buffer, ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symm, left_right, + upper_lower, m, n, alpha, A_buffer, lda, B_buffer, ldb, + beta, C_buffer, ldc); break; default: break; } @@ -138,7 +138,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, m, n, ldc, 10 * std::max(m, n), std::cout); @@ -165,6 +165,8 @@ TEST_P(SymmTests, RealSinglePrecision) { 102, 103, alpha, beta)); } TEST_P(SymmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -197,6 +199,8 @@ TEST_P(SymmTests, ComplexSinglePrecision) { 72, 27, 101, 102, 103, alpha, beta)); } TEST_P(SymmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -215,7 +219,7 @@ TEST_P(SymmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SymmTestSuite, SymmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/symm_usm.cpp b/tests/unit_tests/blas/level3/symm_usm.cpp index 02150fcbd..f774e82e3 100644 --- a/tests/unit_tests/blas/level3/symm_usm.cpp +++ b/tests/unit_tests/blas/level3/symm_usm.cpp @@ -95,7 +95,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::symm(main_queue, left_right, upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, dependencies); @@ -110,15 +110,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symm, left_right, - upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::symm, + left_right, upper_lower, m, n, alpha, A.data(), lda, + B.data(), ldb, beta, C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symm, left_right, - upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, beta, - C.data(), ldc, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::symm, left_right, + upper_lower, m, n, alpha, A.data(), lda, B.data(), ldb, + beta, C.data(), ldc, dependencies); break; default: break; } @@ -165,6 +165,8 @@ TEST_P(SymmUsmTests, RealSinglePrecision) { 102, 103, alpha, beta)); } TEST_P(SymmUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -197,6 +199,8 @@ TEST_P(SymmUsmTests, ComplexSinglePrecision) { 72, 27, 101, 102, 103, alpha, beta)); } TEST_P(SymmUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -215,7 +219,7 @@ TEST_P(SymmUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SymmUsmTestSuite, SymmUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/syr2k.cpp b/tests/unit_tests/blas/level3/syr2k.cpp index 86290b727..0153e9ec0 100644 --- a/tests/unit_tests/blas/level3/syr2k.cpp +++ b/tests/unit_tests/blas/level3/syr2k.cpp @@ -92,7 +92,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::syr2k(main_queue, upper_lower, trans, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, ldc); @@ -106,15 +106,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2k, upper_lower, - trans, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, - ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2k, + upper_lower, trans, n, k, alpha, A_buffer, lda, B_buffer, + ldb, beta, C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2k, upper_lower, - trans, n, k, alpha, A_buffer, lda, B_buffer, ldb, beta, C_buffer, - ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2k, + upper_lower, trans, n, k, alpha, A_buffer, lda, B_buffer, + ldb, beta, C_buffer, ldc); break; default: break; } @@ -134,7 +134,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, n, n, ldc, 10 * std::max(n, k), std::cout); @@ -161,6 +161,8 @@ TEST_P(Syr2kTests, RealSinglePrecision) { 101, 102, 103, alpha, beta)); } TEST_P(Syr2kTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(3.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -193,6 +195,8 @@ TEST_P(Syr2kTests, ComplexSinglePrecision) { oneapi::mkl::transpose::trans, 73, 27, 101, 102, 103, alpha, beta)); } TEST_P(Syr2kTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(3.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>( @@ -211,7 +215,7 @@ TEST_P(Syr2kTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Syr2kTestSuite, Syr2kTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/syr2k_usm.cpp b/tests/unit_tests/blas/level3/syr2k_usm.cpp index d7dde1791..efa3f07d3 100644 --- a/tests/unit_tests/blas/level3/syr2k_usm.cpp +++ b/tests/unit_tests/blas/level3/syr2k_usm.cpp @@ -92,7 +92,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::syr2k(main_queue, upper_lower, trans, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), ldc, dependencies); @@ -107,15 +107,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2k, upper_lower, - trans, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), - ldc, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syr2k, + upper_lower, trans, n, k, alpha, A.data(), lda, B.data(), + ldb, beta, C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2k, upper_lower, - trans, n, k, alpha, A.data(), lda, B.data(), ldb, beta, C.data(), - ldc, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syr2k, + upper_lower, trans, n, k, alpha, A.data(), lda, B.data(), + ldb, beta, C.data(), ldc, dependencies); break; default: break; } @@ -162,6 +162,8 @@ TEST_P(Syr2kUsmTests, RealSinglePrecision) { 101, 102, 103, alpha, beta)); } TEST_P(Syr2kUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(3.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -194,6 +196,8 @@ TEST_P(Syr2kUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::trans, 73, 27, 101, 102, 103, alpha, beta)); } TEST_P(Syr2kUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(3.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>( @@ -212,7 +216,7 @@ TEST_P(Syr2kUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(Syr2kUsmTestSuite, Syr2kUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/syrk.cpp b/tests/unit_tests/blas/level3/syrk.cpp index 375665cff..a6b28735d 100644 --- a/tests/unit_tests/blas/level3/syrk.cpp +++ b/tests/unit_tests/blas/level3/syrk.cpp @@ -90,7 +90,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::syrk(main_queue, upper_lower, trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); break; @@ -102,13 +102,14 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk, upper_lower, - trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk, + upper_lower, trans, n, k, alpha, A_buffer, lda, beta, + C_buffer, ldc); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk, upper_lower, - trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk, upper_lower, + trans, n, k, alpha, A_buffer, lda, beta, C_buffer, ldc); break; default: break; } @@ -128,7 +129,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, } // Compare the results of reference implementation and DPC++ implementation. - auto C_accessor = C_buffer.template get_host_access(read_only); + auto C_accessor = C_buffer.get_host_access(read_only); bool good = check_equal_matrix(C_accessor, C_ref, layout, n, n, ldc, 10 * std::max(n, k), std::cout); @@ -155,6 +156,8 @@ TEST_P(SyrkTests, RealSinglePrecision) { 101, 103, alpha, beta)); } TEST_P(SyrkTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(3.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -187,6 +190,8 @@ TEST_P(SyrkTests, ComplexSinglePrecision) { oneapi::mkl::transpose::trans, 73, 27, 101, 103, alpha, beta)); } TEST_P(SyrkTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(3.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>( @@ -205,7 +210,7 @@ TEST_P(SyrkTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SyrkTestSuite, SyrkTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/syrk_usm.cpp b/tests/unit_tests/blas/level3/syrk_usm.cpp index 9c3a01104..e5569eb78 100644 --- a/tests/unit_tests/blas/level3/syrk_usm.cpp +++ b/tests/unit_tests/blas/level3/syrk_usm.cpp @@ -90,7 +90,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::syrk(main_queue, upper_lower, trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, dependencies); @@ -105,15 +105,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::uplo upper_lower, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk, upper_lower, - trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, - dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::syrk, + upper_lower, trans, n, k, alpha, A.data(), lda, beta, + C.data(), ldc, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk, upper_lower, - trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, - dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::syrk, upper_lower, + trans, n, k, alpha, A.data(), lda, beta, C.data(), ldc, + dependencies); break; default: break; } @@ -160,6 +160,8 @@ TEST_P(SyrkUsmTests, RealSinglePrecision) { 101, 103, alpha, beta)); } TEST_P(SyrkUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(3.0); double beta(3.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), @@ -192,6 +194,8 @@ TEST_P(SyrkUsmTests, ComplexSinglePrecision) { oneapi::mkl::transpose::trans, 73, 27, 101, 103, alpha, beta)); } TEST_P(SyrkUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(3.0, -0.5); std::complex beta(3.0, -1.5); EXPECT_TRUEORSKIP(test>( @@ -210,7 +214,7 @@ TEST_P(SyrkUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(SyrkUsmTestSuite, SyrkUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/trmm.cpp b/tests/unit_tests/blas/level3/trmm.cpp index 6255df29b..2a02aa0d1 100644 --- a/tests/unit_tests/blas/level3/trmm.cpp +++ b/tests/unit_tests/blas/level3/trmm.cpp @@ -96,7 +96,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::trmm(main_queue, left_right, upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, lda, B_buffer, ldb); @@ -110,15 +110,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, lda, - B_buffer, ldb); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmm, + left_right, upper_lower, transa, unit_nonunit, m, n, alpha, + A_buffer, lda, B_buffer, ldb); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, lda, - B_buffer, ldb); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmm, left_right, + upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, + lda, B_buffer, ldb); break; default: break; } @@ -138,7 +138,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } // Compare the results of reference implementation and DPC++ implementation. - auto B_accessor = B_buffer.template get_host_access(read_only); + auto B_accessor = B_buffer.get_host_access(read_only); bool good = check_equal_matrix(B_accessor, B_ref, layout, m, n, ldb, 10 * std::max(m, n), std::cout); @@ -192,6 +192,8 @@ TEST_P(TrmmTests, RealSinglePrecision) { 101, 102, alpha)); } TEST_P(TrmmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -294,6 +296,8 @@ TEST_P(TrmmTests, ComplexSinglePrecision) { 27, 101, 102, alpha)); } TEST_P(TrmmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -355,7 +359,7 @@ TEST_P(TrmmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrmmTestSuite, TrmmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/trmm_usm.cpp b/tests/unit_tests/blas/level3/trmm_usm.cpp index 06b460cd8..1fa9bbdb0 100644 --- a/tests/unit_tests/blas/level3/trmm_usm.cpp +++ b/tests/unit_tests/blas/level3/trmm_usm.cpp @@ -97,7 +97,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::trmm( main_queue, left_right, upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), lda, B.data(), ldb, dependencies); @@ -112,15 +112,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), lda, - B.data(), ldb, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trmm, + left_right, upper_lower, transa, unit_nonunit, m, n, alpha, + A.data(), lda, B.data(), ldb, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), lda, - B.data(), ldb, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trmm, left_right, + upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), + lda, B.data(), ldb, dependencies); break; default: break; } @@ -194,6 +194,8 @@ TEST_P(TrmmUsmTests, RealSinglePrecision) { 101, 102, alpha)); } TEST_P(TrmmUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -296,6 +298,8 @@ TEST_P(TrmmUsmTests, ComplexSinglePrecision) { 27, 101, 102, alpha)); } TEST_P(TrmmUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -357,7 +361,7 @@ TEST_P(TrmmUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrmmUsmTestSuite, TrmmUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/trsm.cpp b/tests/unit_tests/blas/level3/trsm.cpp index 5023229e8..90b8d5c93 100644 --- a/tests/unit_tests/blas/level3/trsm.cpp +++ b/tests/unit_tests/blas/level3/trsm.cpp @@ -96,7 +96,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: oneapi::mkl::blas::column_major::trsm(main_queue, left_right, upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, lda, B_buffer, ldb); @@ -110,15 +110,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, lda, - B_buffer, ldb); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm, + left_right, upper_lower, transa, unit_nonunit, m, n, alpha, + A_buffer, lda, B_buffer, ldb); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, lda, - B_buffer, ldb); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm, left_right, + upper_lower, transa, unit_nonunit, m, n, alpha, A_buffer, + lda, B_buffer, ldb); break; default: break; } @@ -138,7 +138,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, } // Compare the results of reference implementation and DPC++ implementation. - auto B_accessor = B_buffer.template get_host_access(read_only); + auto B_accessor = B_buffer.get_host_access(read_only); bool good = check_equal_trsm_matrix(B_accessor, B_ref, layout, m, n, ldb, 10 * std::max(m, n), std::cout); @@ -216,6 +216,8 @@ TEST_P(TrsmTests, RealSinglePrecision) { 101, 102, alpha)); } TEST_P(TrsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -382,6 +384,8 @@ TEST_P(TrsmTests, ComplexSinglePrecision) { 27, 101, 102, alpha)); } TEST_P(TrsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -483,7 +487,7 @@ TEST_P(TrsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrsmTestSuite, TrsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/blas/level3/trsm_usm.cpp b/tests/unit_tests/blas/level3/trsm_usm.cpp index c893c6a73..f84b0ed61 100644 --- a/tests/unit_tests/blas/level3/trsm_usm.cpp +++ b/tests/unit_tests/blas/level3/trsm_usm.cpp @@ -97,7 +97,7 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, try { #ifdef CALL_RT_API switch (layout) { - case oneapi::mkl::layout::column_major: + case oneapi::mkl::layout::col_major: done = oneapi::mkl::blas::column_major::trsm( main_queue, left_right, upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), lda, B.data(), ldb, dependencies); @@ -112,15 +112,15 @@ int test(device* dev, oneapi::mkl::layout layout, oneapi::mkl::side left_right, done.wait(); #else switch (layout) { - case oneapi::mkl::layout::column_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), lda, - B.data(), ldb, dependencies); + case oneapi::mkl::layout::col_major: + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::column_major::trsm, + left_right, upper_lower, transa, unit_nonunit, m, n, alpha, + A.data(), lda, B.data(), ldb, dependencies); break; case oneapi::mkl::layout::row_major: - TEST_RUN_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm, left_right, - upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), lda, - B.data(), ldb, dependencies); + TEST_RUN_BLAS_CT_SELECT(main_queue, oneapi::mkl::blas::row_major::trsm, left_right, + upper_lower, transa, unit_nonunit, m, n, alpha, A.data(), + lda, B.data(), ldb, dependencies); break; default: break; } @@ -219,6 +219,8 @@ TEST_P(TrsmUsmTests, RealSinglePrecision) { 101, 102, alpha)); } TEST_P(TrsmUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + double alpha(2.0); EXPECT_TRUEORSKIP(test(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -385,6 +387,8 @@ TEST_P(TrsmUsmTests, ComplexSinglePrecision) { 27, 101, 102, alpha)); } TEST_P(TrsmUsmTests, ComplexDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(std::get<0>(GetParam())); + std::complex alpha(2.0, -0.5); EXPECT_TRUEORSKIP(test>(std::get<0>(GetParam()), std::get<1>(GetParam()), oneapi::mkl::side::left, oneapi::mkl::uplo::lower, @@ -486,7 +490,7 @@ TEST_P(TrsmUsmTests, ComplexDoublePrecision) { INSTANTIATE_TEST_SUITE_P(TrsmUsmTestSuite, TrsmUsmTests, ::testing::Combine(testing::ValuesIn(devices), - testing::Values(oneapi::mkl::layout::column_major, + testing::Values(oneapi::mkl::layout::col_major, oneapi::mkl::layout::row_major)), ::LayoutDeviceNamePrint()); diff --git a/tests/unit_tests/dft/include/compute_inplace.hpp b/tests/unit_tests/dft/include/compute_inplace.hpp index c067c9945..9cc161c34 100644 --- a/tests/unit_tests/dft/include/compute_inplace.hpp +++ b/tests/unit_tests/dft/include/compute_inplace.hpp @@ -23,159 +23,96 @@ #include "compute_tester.hpp" #include -inline std::int64_t row_elements_to_conjugate_even_components(std::int64_t last_dim) { - return ((last_dim / 2) + 1) * 2; -} - -std::vector get_conjugate_even_real_component_strides( - const std::vector& sizes) { - switch (sizes.size()) { - case 1: return { 0, 1 }; - case 2: return { 0, 2 * (sizes[1] / 2 + 1), 1 }; - case 3: return { 0, 2 * sizes[1] * (sizes[2] / 2 + 1), 2 * (sizes[2] / 2 + 1), 1 }; - default: - throw oneapi::mkl::unimplemented( - "compute_inplace", __FUNCTION__, - "not implemented for " + std::to_string(sizes.size()) + " dimensions"); - return {}; - } -} - -template -std::vector get_conjugate_even_ref(const std::vector& sizes, std::int64_t batches, - std::vector> output_ref) { - const std::int64_t conjugate_even_last_dim = - row_elements_to_conjugate_even_components(sizes.back()); - const std::int64_t rows = - std::accumulate(sizes.begin(), sizes.end() - 1, batches, std::multiplies<>{}); - std::vector conjugate_even_ref(rows * conjugate_even_last_dim); - for (int j = 0; j < rows; j++) { - for (int i = 0; i < conjugate_even_last_dim; i += 2) { - conjugate_even_ref[j * conjugate_even_last_dim + i] = - output_ref[j * sizes.back() + i / 2].real(); - conjugate_even_ref[j * conjugate_even_last_dim + i + 1] = - output_ref[j * sizes.back() + i / 2].imag(); - } - } - return conjugate_even_ref; -} - -template -void copy_strided(const std::vector& sizes, const std::vector& input, - std::vector& output) { - auto in_iter = input.cbegin(); - auto out_iter = output.begin(); - const auto row_len = sizes.back(); - const auto conjugate_row_len = row_elements_to_conjugate_even_components(row_len); - while (in_iter < input.cend()) { - std::copy(in_iter, in_iter + row_len, out_iter); - in_iter += row_len; - out_iter += conjugate_row_len; - } -} - template int DFT_Test::test_in_place_buffer() { if (!init(MemoryAccessModel::buffer)) { return test_skipped; } - const std::int64_t container_size_total = - domain == oneapi::mkl::dft::domain::REAL - ? (size_total / sizes.back()) * - (row_elements_to_conjugate_even_components(sizes.back())) - : size_total; - const std::int64_t container_size_per_transform = container_size_total / batches; - const std::int64_t backward_elements = domain == oneapi::mkl::dft::domain::REAL - ? container_size_per_transform / 2 - : container_size_per_transform; + auto modified_strides_fwd = this->strides_fwd; + auto modified_strides_bwd = this->strides_bwd; + if (domain == oneapi::mkl::dft::domain::REAL) { + // both input and output strides must be set + auto default_conjuate_strides = get_conjugate_even_complex_strides(sizes); + std::ptrdiff_t rank = static_cast(sizes.size()); + + if (modified_strides_fwd.size() == 0) { + modified_strides_fwd = std::vector( + default_conjuate_strides.begin(), default_conjuate_strides.begin() + rank + 1); + std::transform(modified_strides_fwd.begin() + 1, modified_strides_fwd.begin() + rank, + modified_strides_fwd.begin() + 1, [](std::int64_t& s) { return 2 * s; }); + } + if (modified_strides_bwd.size() == 0) { + modified_strides_bwd = std::vector( + default_conjuate_strides.begin(), default_conjuate_strides.begin() + rank + 1); + } + } + else { + // General consistency requirements for in-place complex domain transforms require that strides are the same forward and backward. + modified_strides_fwd = modified_strides_bwd; + } + + auto [forward_distance, backward_distance] = + get_default_distances(sizes, modified_strides_fwd, modified_strides_bwd); + auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); descriptor_t descriptor{ sizes }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - container_size_per_transform); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_elements); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements)); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto real_strides = get_conjugate_even_real_component_strides(sizes); - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, real_strides.data()); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - complex_strides.data()); + descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE, + oneapi::mkl::dft::config_value::COMPLEX_COMPLEX); + descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT, + oneapi::mkl::dft::config_value::CCE_FORMAT); } - - commit_descriptor(descriptor, sycl_queue); - - std::vector inout_host(container_size_total, 0); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - copy_strided(sizes, input, inout_host); + descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance); + if (modified_strides_fwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, + modified_strides_fwd.data()); } - else { - std::copy(input.begin(), input.end(), inout_host.begin()); + if (modified_strides_bwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, + modified_strides_bwd.data()); } + commit_descriptor(descriptor, sycl_queue); + + std::vector inout_host( + strided_copy(input, sizes, modified_strides_fwd, batches, forward_distance)); + int real_multiplier = (domain == oneapi::mkl::dft::domain::REAL ? 2 : 1); + inout_host.resize( + cast_unsigned(std::max(forward_distance, real_multiplier * backward_distance) * batches + + get_default(modified_strides_bwd, 0, 0L) * real_multiplier)); { sycl::buffer inout_buf{ inout_host }; - try { - oneapi::mkl::dft::compute_forward(descriptor, inout_buf); - } - catch (oneapi::mkl::unimplemented& e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + oneapi::mkl::dft::compute_forward(descriptor, inout_buf); { - auto acc_host = inout_buf.template get_host_access(); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - std::vector conjugate_even_ref = - get_conjugate_even_ref(sizes, batches, out_host_ref); - EXPECT_TRUE(check_equal_vector(acc_host.get_pointer(), conjugate_even_ref.data(), - inout_host.size(), abs_error_margin, - rel_error_margin, std::cout)); - } - else { - EXPECT_TRUE(check_equal_vector(acc_host.get_pointer(), out_host_ref.data(), - inout_host.size(), abs_error_margin, - rel_error_margin, std::cout)); + auto acc_host = inout_buf.get_host_access(); + auto ptr_host = reinterpret_cast(acc_host.get_pointer()); + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided( + ptr_host + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes, + modified_strides_bwd, abs_error_margin, rel_error_margin, std::cout)); } } - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto real_strides = get_conjugate_even_real_component_strides(sizes); - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - complex_strides.data()); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - real_strides.data()); - commit_descriptor(descriptor, sycl_queue); - } - - try { - oneapi::mkl::dft::compute_backward, - FwdInputType>(descriptor, inout_buf); - } - catch (oneapi::mkl::unimplemented& e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + oneapi::mkl::dft::compute_backward, + FwdInputType>(descriptor, inout_buf); } - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - for (int j = 0; j < size_total / sizes.back(); j++) { - EXPECT_TRUE(check_equal_vector( - inout_host.data() + j * row_elements_to_conjugate_even_components(sizes.back()), - input.data() + j * sizes.back(), sizes.back(), abs_error_margin, rel_error_margin, - std::cout)); - } - } - else { - EXPECT_TRUE(check_equal_vector(inout_host.data(), input.data(), input.size(), - abs_error_margin, rel_error_margin, std::cout)); + std::vector fwd_data_ref = input; + // account for scaling that occurs during DFT + std::for_each(fwd_data_ref.begin(), fwd_data_ref.end(), + [this](auto& x) { x *= static_cast(forward_elements); }); + + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided( + inout_host.data() + forward_distance * i, fwd_data_ref.data() + ref_distance * i, sizes, + modified_strides_fwd, abs_error_margin, rel_error_margin, std::cout)); } return !::testing::Test::HasFailure(); @@ -187,97 +124,88 @@ int DFT_Test::test_in_place_USM() { return test_skipped; } - const int64_t container_size_total = - domain == oneapi::mkl::dft::domain::REAL - ? (size_total / sizes.back()) * row_elements_to_conjugate_even_components(sizes.back()) - : size_total; - const int64_t container_size_per_transform = container_size_total / batches; - const std::int64_t backward_elements = domain == oneapi::mkl::dft::domain::REAL - ? container_size_per_transform / 2 - : container_size_per_transform; - - descriptor_t descriptor = { sizes }; - descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - container_size_per_transform); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_elements); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements)); + auto modified_strides_fwd = this->strides_fwd; + auto modified_strides_bwd = this->strides_bwd; + if (domain == oneapi::mkl::dft::domain::REAL) { + // both input and output strides must be set + auto default_conjuate_strides = get_conjugate_even_complex_strides(sizes); + std::ptrdiff_t rank = static_cast(sizes.size()); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto real_strides = get_conjugate_even_real_component_strides(sizes); - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, real_strides.data()); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - complex_strides.data()); - } - - commit_descriptor(descriptor, sycl_queue); - - auto ua_input = usm_allocator_t(cxt, *dev); - std::vector inout(container_size_total, ua_input); - - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - copy_strided(sizes, input, inout); + if (modified_strides_fwd.size() == 0) { + modified_strides_fwd = std::vector( + default_conjuate_strides.begin(), default_conjuate_strides.begin() + rank + 1); + std::transform(modified_strides_fwd.begin() + 1, modified_strides_fwd.begin() + rank, + modified_strides_fwd.begin() + 1, [](std::int64_t& s) { return 2 * s; }); + } + if (modified_strides_bwd.size() == 0) { + modified_strides_bwd = std::vector( + default_conjuate_strides.begin(), default_conjuate_strides.begin() + rank + 1); + } } else { - std::copy(input.begin(), input.end(), inout.begin()); - } - - try { - std::vector dependencies; - oneapi::mkl::dft::compute_forward(descriptor, inout.data(), - dependencies) - .wait(); - } - catch (oneapi::mkl::unimplemented& e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + // General consistency requirements for in-place complex domain transforms require that strides are the same forward and backward. + modified_strides_fwd = modified_strides_bwd; } - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - std::vector conjugate_even_ref = - get_conjugate_even_ref(sizes, batches, out_host_ref); - EXPECT_TRUE(check_equal_vector(inout.data(), conjugate_even_ref.data(), inout.size(), - abs_error_margin, rel_error_margin, std::cout)); - } - else { - EXPECT_TRUE(check_equal_vector(inout.data(), out_host_ref.data(), inout.size(), - abs_error_margin, rel_error_margin, std::cout)); - } + auto [forward_distance, backward_distance] = + get_default_distances(sizes, modified_strides_fwd, modified_strides_bwd); + auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); + descriptor_t descriptor = { sizes }; + descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, + oneapi::mkl::dft::config_value::INPLACE); if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto real_strides = get_conjugate_even_real_component_strides(sizes); - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, complex_strides.data()); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, real_strides.data()); - commit_descriptor(descriptor, sycl_queue); + descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE, + oneapi::mkl::dft::config_value::COMPLEX_COMPLEX); + descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT, + oneapi::mkl::dft::config_value::CCE_FORMAT); } - - try { - std::vector dependencies; - sycl::event done = - oneapi::mkl::dft::compute_backward, - FwdInputType>(descriptor, inout.data()); - done.wait(); + descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance); + if (modified_strides_fwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, + modified_strides_fwd.data()); } - catch (oneapi::mkl::unimplemented& e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + if (modified_strides_bwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, + modified_strides_bwd.data()); } + commit_descriptor(descriptor, sycl_queue); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - for (int j = 0; j < size_total / sizes.back(); j++) { - EXPECT_TRUE(check_equal_vector( - inout.data() + j * row_elements_to_conjugate_even_components(sizes.back()), - input.data() + j * sizes.back(), sizes.back(), abs_error_margin, rel_error_margin, - std::cout)); - } - } - else { - EXPECT_TRUE(check_equal_vector(inout.data(), input.data(), input.size(), abs_error_margin, - rel_error_margin, std::cout)); + auto ua_input = usm_allocator_t(cxt, *dev); + std::vector inout( + strided_copy(input, sizes, modified_strides_fwd, batches, forward_distance, ua_input), + ua_input); + int real_multiplier = (domain == oneapi::mkl::dft::domain::REAL ? 2 : 1); + inout.resize( + cast_unsigned(std::max(forward_distance, real_multiplier * backward_distance) * batches + + real_multiplier * get_default(modified_strides_bwd, 0, 0L))); + + std::vector no_dependencies; + oneapi::mkl::dft::compute_forward(descriptor, inout.data(), + no_dependencies) + .wait_and_throw(); + + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided( + reinterpret_cast(inout.data()) + backward_distance * i, + out_host_ref.data() + ref_distance * i, sizes, modified_strides_bwd, abs_error_margin, + rel_error_margin, std::cout)); + } + + sycl::event done = + oneapi::mkl::dft::compute_backward, + FwdInputType>(descriptor, inout.data(), no_dependencies); + done.wait_and_throw(); + + std::for_each(input.begin(), input.end(), + [this](auto& x) { x *= static_cast(forward_elements); }); + + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided( + inout.data() + forward_distance * i, input.data() + ref_distance * i, sizes, + modified_strides_fwd, abs_error_margin, rel_error_margin, std::cout)); } return !::testing::Test::HasFailure(); diff --git a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp index fed039803..d4af1a44a 100644 --- a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp +++ b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp @@ -22,28 +22,28 @@ #include "compute_tester.hpp" -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_in_place_real_real_USM() { if (!init(MemoryAccessModel::usm)) { return test_skipped; } + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; - try { + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; - + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, - (1.0 / forward_elements)); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); @@ -53,69 +53,100 @@ int DFT_Test::test_in_place_real_real_USM() { std::copy(input_re.begin(), input_re.end(), inout_re.begin()); std::copy(input_im.begin(), input_im.end(), inout_im.begin()); - std::vector dependencies; - sycl::event done = oneapi::mkl::dft::compute_forward( - descriptor, inout_re.data(), inout_im.data(), dependencies); - done.wait(); + std::vector no_dependencies; + oneapi::mkl::dft::compute_forward( + descriptor, inout_re.data(), inout_im.data(), no_dependencies) + .wait_and_throw(); - done = oneapi::mkl::dft::compute_backward, - PrecisionType>(descriptor, inout_re.data(), - inout_im.data(), dependencies); - done.wait(); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + std::vector output_data(size_total); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { inout_re[i], inout_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), output_data.size(), + abs_error_margin, rel_error_margin, std::cout)); + + oneapi::mkl::dft::compute_backward, + PrecisionType>(descriptor, inout_re.data(), + inout_im.data(), no_dependencies) + .wait_and_throw(); - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { inout_re[i], inout_im[i] }; + } - return !::testing::Test::HasFailure(); + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + + return !::testing::Test::HasFailure(); + } } -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_in_place_real_real_buffer() { if (!init(MemoryAccessModel::buffer)) { return test_skipped; } - try { + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; + + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, - (1.0 / forward_elements)); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); - sycl::buffer inout_re_buf{ input_re.data(), sycl::range<1>(size_total) }; - sycl::buffer inout_im_buf{ input_im.data(), sycl::range<1>(size_total) }; + std::vector host_inout_re(size_total, static_cast(0)); + std::vector host_inout_im(size_total, static_cast(0)); + std::copy(input_re.begin(), input_re.end(), host_inout_re.begin()); + std::copy(input_im.begin(), input_im.end(), host_inout_im.begin()); + + sycl::buffer inout_re_buf{ host_inout_re.data(), + sycl::range<1>(size_total) }; + sycl::buffer inout_im_buf{ host_inout_im.data(), + sycl::range<1>(size_total) }; oneapi::mkl::dft::compute_forward(descriptor, inout_re_buf, inout_im_buf); + { + auto acc_inout_re = inout_re_buf.get_host_access(); + auto acc_inout_im = inout_im_buf.get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_inout_re[i], acc_inout_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), + output_data.size(), abs_error_margin, rel_error_margin, + std::cout)); + } + oneapi::mkl::dft::compute_backward, PrecisionType>(descriptor, inout_re_buf, inout_im_buf); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); - - return !::testing::Test::HasFailure(); + { + auto acc_inout_re = inout_re_buf.get_host_access(); + auto acc_inout_im = inout_im_buf.get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_inout_re[i], acc_inout_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + } + return !::testing::Test::HasFailure(); + } } #endif //ONEMKL_COMPUTE_INPLACE_REAL_REAL_HPP diff --git a/tests/unit_tests/dft/include/compute_out_of_place.hpp b/tests/unit_tests/dft/include/compute_out_of_place.hpp index 19b12386f..df5e1e323 100644 --- a/tests/unit_tests/dft/include/compute_out_of_place.hpp +++ b/tests/unit_tests/dft/include/compute_out_of_place.hpp @@ -23,93 +23,76 @@ #include "compute_tester.hpp" #include -template -std::int64_t get_backward_row_size(const std::vector &sizes) noexcept { - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - return sizes.back() / 2 + 1; - } - else { - return sizes.back(); - } -} - template int DFT_Test::test_out_of_place_buffer() { if (!init(MemoryAccessModel::buffer)) { return test_skipped; } - const auto backward_distance = std::accumulate( - sizes.begin(), sizes.end() - 1, get_backward_row_size(sizes), std::multiplies<>()); + auto [forward_distance, backward_distance] = + get_default_distances(sizes, strides_fwd, strides_bwd); + auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); descriptor_t descriptor{ sizes }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE, + oneapi::mkl::dft::config_value::COMPLEX_COMPLEX); + descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT, + oneapi::mkl::dft::config_value::CCE_FORMAT); + } descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements)); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + if (strides_fwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data()); + } + if (strides_bwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data()); + } + else if constexpr (domain == oneapi::mkl::dft::domain::REAL) { const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - complex_strides.data()); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data()); } commit_descriptor(descriptor, sycl_queue); + std::vector fwd_data( + strided_copy(input, sizes, strides_fwd, batches, forward_distance)); - std::vector fwd_data(input); - + auto tmp = std::vector( + cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), 0); { sycl::buffer fwd_buf{ fwd_data }; - sycl::buffer bwd_buf{ sycl::range<1>(backward_distance * batches) }; + sycl::buffer bwd_buf{ tmp }; - try { - oneapi::mkl::dft::compute_forward( - descriptor, fwd_buf, bwd_buf); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + oneapi::mkl::dft::compute_forward( + descriptor, fwd_buf, bwd_buf); { - auto acc_bwd = bwd_buf.template get_host_access(); + auto acc_bwd = bwd_buf.get_host_access(); auto bwd_ptr = acc_bwd.get_pointer(); - auto ref_iter = out_host_ref.begin(); - const auto ref_row_stride = sizes.back(); - const auto backward_row_stride = get_backward_row_size(sizes); - const auto backward_row_elements = get_backward_row_size(sizes); - - while (ref_iter < out_host_ref.end()) { - EXPECT_TRUE(check_equal_vector(bwd_ptr, ref_iter, backward_row_elements, - abs_error_margin, rel_error_margin, std::cout)); - bwd_ptr += backward_row_stride; - ref_iter += ref_row_stride; + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided( + bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes, + strides_bwd, abs_error_margin, rel_error_margin, std::cout)); } } - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - auto real_strides = get_default_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, - complex_strides.data()); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - real_strides.data()); - commit_descriptor(descriptor, sycl_queue); - } + oneapi::mkl::dft::compute_backward, + FwdOutputType, FwdInputType>(descriptor, bwd_buf, + fwd_buf); + } - try { - oneapi::mkl::dft::compute_backward, - FwdOutputType, FwdInputType>(descriptor, bwd_buf, - fwd_buf); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + // account for scaling that occurs during DFT + std::for_each(input.begin(), input.end(), + [this](auto &x) { x *= static_cast(forward_elements); }); + + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided(fwd_data.data() + forward_distance * i, + input.data() + ref_distance * i, sizes, strides_fwd, + abs_error_margin, rel_error_margin, std::cout)); } - EXPECT_TRUE(check_equal_vector(fwd_data.data(), input.data(), input.size(), abs_error_margin, - rel_error_margin, std::cout)); return !::testing::Test::HasFailure(); } @@ -120,77 +103,68 @@ int DFT_Test::test_out_of_place_USM() { } const std::vector no_dependencies; - const auto backward_distance = std::accumulate( - sizes.begin(), sizes.end() - 1, get_backward_row_size(sizes), std::multiplies<>()); + auto [forward_distance, backward_distance] = + get_default_distances(sizes, strides_fwd, strides_bwd); descriptor_t descriptor{ sizes }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + descriptor.set_value(oneapi::mkl::dft::config_param::CONJUGATE_EVEN_STORAGE, + oneapi::mkl::dft::config_value::COMPLEX_COMPLEX); + descriptor.set_value(oneapi::mkl::dft::config_param::PACKED_FORMAT, + oneapi::mkl::dft::config_value::CCE_FORMAT); + } descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_distance); descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements)); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + if (strides_fwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, strides_fwd.data()); + } + if (strides_bwd.size()) { + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, strides_bwd.data()); + } + else if constexpr (domain == oneapi::mkl::dft::domain::REAL) { const auto complex_strides = get_conjugate_even_complex_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, - complex_strides.data()); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, complex_strides.data()); } commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); auto ua_output = usm_allocator_t(cxt, *dev); - std::vector fwd(input.begin(), input.end(), ua_input); - std::vector bwd(backward_distance * batches, ua_output); - - try { - oneapi::mkl::dft::compute_forward( - descriptor, fwd.data(), bwd.data(), no_dependencies) - .wait(); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + std::vector fwd( + strided_copy(input, sizes, strides_fwd, batches, forward_distance, ua_input), ua_input); + std::vector bwd( + cast_unsigned(backward_distance * batches + get_default(strides_bwd, 0, 0L)), ua_output); + + oneapi::mkl::dft::compute_forward( + descriptor, fwd.data(), bwd.data(), no_dependencies) + .wait_and_throw(); + + auto bwd_ptr = &bwd[0]; + auto ref_distance = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>()); + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided( + bwd_ptr + backward_distance * i, out_host_ref.data() + ref_distance * i, sizes, + strides_bwd, abs_error_margin, rel_error_margin, std::cout)); } - { - auto bwd_iter = bwd.begin(); - auto ref_iter = out_host_ref.begin(); - - const auto ref_row_stride = sizes.back(); - const auto backward_row_stride = get_backward_row_size(sizes); - const auto backward_row_elements = get_backward_row_size(sizes); - - while (ref_iter < out_host_ref.end()) { - EXPECT_TRUE(check_equal_vector(bwd_iter, ref_iter, backward_row_elements, - abs_error_margin, rel_error_margin, std::cout)); - bwd_iter += backward_row_stride; - ref_iter += ref_row_stride; - } - } + oneapi::mkl::dft::compute_backward, FwdOutputType, + FwdInputType>(descriptor, bwd.data(), fwd.data(), + no_dependencies) + .wait_and_throw(); - if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - const auto complex_strides = get_conjugate_even_complex_strides(sizes); - auto real_strides = get_default_strides(sizes); - descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, complex_strides.data()); - descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, real_strides.data()); - commit_descriptor(descriptor, sycl_queue); - } + // account for scaling that occurs during DFT + std::for_each(input.begin(), input.end(), + [this](auto &x) { x *= static_cast(forward_elements); }); - try { - oneapi::mkl::dft::compute_backward, - FwdOutputType, FwdInputType>(descriptor, bwd.data(), - fwd.data(), no_dependencies) - .wait(); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; + for (std::int64_t i = 0; i < batches; i++) { + EXPECT_TRUE(check_equal_strided(fwd.data() + forward_distance * i, + input.data() + ref_distance * i, sizes, strides_fwd, + abs_error_margin, rel_error_margin, std::cout)); } - EXPECT_TRUE(check_equal_vector(fwd.data(), input.data(), input.size(), abs_error_margin, - rel_error_margin, std::cout)); - return !::testing::Test::HasFailure(); } diff --git a/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp b/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp index 9a9730fd0..fb3ecb4f2 100644 --- a/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp +++ b/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp @@ -22,28 +22,30 @@ #include "compute_tester.hpp" -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_out_of_place_real_real_USM() { if (!init(MemoryAccessModel::usm)) { return test_skipped; } - try { + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; + + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, - (1.0 / forward_elements)); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); @@ -59,50 +61,59 @@ int DFT_Test::test_out_of_place_real_real_USM() { std::copy(input_re.begin(), input_re.end(), in_re.begin()); std::copy(input_im.begin(), input_im.end(), in_im.begin()); - std::vector dependencies; - sycl::event done = - oneapi::mkl::dft::compute_forward( - descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), dependencies); - done.wait(); + std::vector no_dependencies; - done = oneapi::mkl::dft::compute_backward, - PrecisionType, PrecisionType>( - descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data()); - done.wait(); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } + oneapi::mkl::dft::compute_forward( + descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), no_dependencies) + .wait_and_throw(); + std::vector output_data(size_total); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { out_re[i], out_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), output_data.size(), + abs_error_margin, rel_error_margin, std::cout)); + + oneapi::mkl::dft::compute_backward, + PrecisionType, PrecisionType>( + descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data(), + no_dependencies) + .wait_and_throw(); - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { out_back_re[i], out_back_im[i] }; + } + + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + } return !::testing::Test::HasFailure(); } -/* Test is not implemented because currently there are no available dft implementations. - * These are stubs to make sure that dft::oneapi::mkl::unimplemented exception is thrown */ template int DFT_Test::test_out_of_place_real_real_buffer() { if (!init(MemoryAccessModel::buffer)) { return test_skipped; } - try { + if constexpr (domain == oneapi::mkl::dft::domain::REAL) { + std::cout << "skipping real split tests as they are not supported" << std::endl; + + return test_skipped; + } + else { descriptor_t descriptor{ sizes }; + PrecisionType backward_scale = 1.f / static_cast(forward_elements); descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches); - descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, - static_cast(forward_elements)); - descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, - (1.0 / forward_elements)); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, forward_elements); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, backward_scale); + commit_descriptor(descriptor, sycl_queue); sycl::buffer in_dev_re{ input_re.data(), sycl::range<1>(size_total) }; @@ -115,17 +126,33 @@ int DFT_Test::test_out_of_place_real_real_buffer() { oneapi::mkl::dft::compute_forward( descriptor, in_dev_re, in_dev_im, out_dev_re, out_dev_im); + { + auto acc_out_re = out_dev_re.get_host_access(); + auto acc_out_im = out_dev_im.get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_out_re[i], acc_out_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), out_host_ref.data(), + output_data.size(), abs_error_margin, rel_error_margin, + std::cout)); + } + oneapi::mkl::dft::compute_backward, PrecisionType, PrecisionType>( descriptor, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im); - } - catch (oneapi::mkl::unimplemented &e) { - std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; - return test_skipped; - } - /* Once implementations exist, results will need to be verified */ - EXPECT_TRUE(false); + { + auto acc_back_out_re = out_back_dev_re.get_host_access(); + auto acc_back_out_im = out_back_dev_im.get_host_access(); + std::vector output_data(size_total, static_cast(0)); + for (std::size_t i = 0; i < output_data.size(); ++i) { + output_data[i] = { acc_back_out_re[i], acc_back_out_im[i] }; + } + EXPECT_TRUE(check_equal_vector(output_data.data(), input.data(), input.size(), + abs_error_margin, rel_error_margin, std::cout)); + } + } return !::testing::Test::HasFailure(); } diff --git a/tests/unit_tests/dft/include/compute_tester.hpp b/tests/unit_tests/dft/include/compute_tester.hpp index 1b2238555..17ffac0cb 100644 --- a/tests/unit_tests/dft/include/compute_tester.hpp +++ b/tests/unit_tests/dft/include/compute_tester.hpp @@ -52,9 +52,11 @@ struct DFT_Test { enum class MemoryAccessModel { buffer, usm }; const std::vector sizes; + const std::vector strides_fwd; + const std::vector strides_bwd; const std::int64_t batches; const std::int64_t forward_elements; - const std::int64_t size_total; + const std::size_t size_total; double abs_error_margin{ 0 }; double rel_error_margin{ 0 }; @@ -67,12 +69,16 @@ struct DFT_Test { std::vector input_im; std::vector out_host_ref; - DFT_Test(sycl::device* dev, std::vector sizes_, std::int64_t batches_) + DFT_Test(sycl::device* dev, std::vector sizes_, + std::vector strides_fwd, std::vector strides_bwd, + std::int64_t batches_) : sizes{ std::move(sizes_) }, + strides_fwd(std::move(strides_fwd)), + strides_bwd(std::move(strides_bwd)), batches{ batches_ }, forward_elements{ std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>{}) }, - size_total{ forward_elements * batches }, + size_total{ cast_unsigned(forward_elements * batches) }, dev{ dev }, sycl_queue{ *dev, exception_handler }, cxt{ sycl_queue.get_context() } { @@ -88,13 +94,13 @@ struct DFT_Test { rand_vector(input, size_total); if constexpr (domain == oneapi::mkl::dft::domain::REAL) { - for (int i = 0; i < input.size(); ++i) { + for (std::size_t i = 0; i < input.size(); ++i) { input_re[i] = { input[i] }; input_im[i] = 0; } } else { - for (int i = 0; i < input.size(); ++i) { + for (std::size_t i = 0; i < input.size(); ++i) { input_re[i] = { input[i].real() }; input_im[i] = { input[i].imag() }; } @@ -130,8 +136,8 @@ struct DFT_Test { }); // Heuristic for the average-case error margins abs_error_margin = - std::abs(max_norm_ref) * std::log2(static_cast(forward_elements)); - rel_error_margin = 5.0 * std::log2(static_cast(forward_elements)); + 10 * std::abs(max_norm_ref) * std::log2(static_cast(forward_elements)); + rel_error_margin = 200.0 * std::log2(static_cast(forward_elements)); return !skip_test(mem_acc); } diff --git a/tests/unit_tests/dft/include/parseval_check.hpp b/tests/unit_tests/dft/include/parseval_check.hpp new file mode 100644 index 000000000..ece6f7d31 --- /dev/null +++ b/tests/unit_tests/dft/include/parseval_check.hpp @@ -0,0 +1,81 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#ifndef ONEMKL_PARSEVAL_CHECK_HPP +#define ONEMKL_PARSEVAL_CHECK_HPP + +#include +#include +#include +#include +#include + +#include "test_common.hpp" + +/** Use Parseval's theorem to verify the output of DFT. This does not guarantee that the output + * of the DFT is correct, and is only a sanity check. + * + * Check Sum(|in[i]|^2) == Sum(|out[i]|^2). + * + * @tparam TypeFwd Forward domain type + * @tparam TypeBwd Backward domain type + * @param dft_len DFT size + * @param in forward domain data + * @param out bwd domain data + * @param rescale_forward A value to multiply the in data by. +*/ +template +bool parseval_check(std::size_t dft_len, const TypeFwd* in, TypeBwd* out, + TypeFwd rescale_forward = 1) { + static_assert(is_complex()); + bool complex_forward = is_complex(); + auto bwd_len = complex_forward ? dft_len : dft_len / 2 + 1; + + float in_sum{ 0 }; + float out_sum{ 0 }; + for (std::size_t i{ 0 }; i < dft_len; ++i) { + in_sum += static_cast(std::abs(in[i] * rescale_forward) * + std::abs(in[i] * rescale_forward)); + } + if (complex_forward) { + for (std::size_t i{ 0 }; i < bwd_len; ++i) { + out_sum += static_cast(std::abs(out[i]) * std::abs(out[i])); + } + } + else { + for (std::size_t i{ 0 }; i < bwd_len - 1; ++i) { + out_sum += static_cast(std::abs(out[i]) * std::abs(out[i])); + } + out_sum *= 2; + out_sum += static_cast(std::abs(out[bwd_len - 1]) * std::abs(out[bwd_len - 1])); + } + out_sum /= static_cast(dft_len); + auto max_norm_ref = *std::max_element( + in, in + dft_len, [](const auto& a, const auto& b) { return std::abs(a) < std::abs(b); }); + // Heuristic for the average-case error margins + auto abs_error_margin = 10 * std::abs(max_norm_ref) * std::log2(static_cast(dft_len)); + if (std::abs(in_sum - out_sum) > abs_error_margin) { + std::cout << "Failed check with Parseval's theorem: Fwd sum = " << in_sum + << ", Bwd sum = " << out_sum << " (tol = " << abs_error_margin << ")" + << std::endl; + return false; + } + return true; +} +#endif // ONEMKL_PARSEVAL_CHECK_HPP diff --git a/tests/unit_tests/dft/include/reference_dft.hpp b/tests/unit_tests/dft/include/reference_dft.hpp index 4ee511661..236edc7b0 100644 --- a/tests/unit_tests/dft/include/reference_dft.hpp +++ b/tests/unit_tests/dft/include/reference_dft.hpp @@ -20,6 +20,7 @@ #ifndef ONEMKL_REFERENCE_DFT_HPP #define ONEMKL_REFERENCE_DFT_HPP +#include #include #include #include @@ -31,7 +32,7 @@ namespace detail { using ref_t = long double; /* Do the calculations using long double */ template -void reference_forward_dft_impl(const TypeIn *in, TypeOut *out, size_t N, size_t stride) { +void reference_forward_dft_impl(const TypeIn *in, TypeOut *out, std::size_t N, std::size_t stride) { static_assert(is_complex(), "Output type of DFT must be complex"); constexpr ref_t TWOPI = 2.0L * 3.141592653589793238462643383279502884197L; @@ -53,22 +54,20 @@ struct reference {}; template struct reference { - static void forward_dft(const std::vector &sizes, const TypeIn *in, - TypeOut *out) { + static void forward_dft(const std::vector &sizes, const TypeIn *in, TypeOut *out) { reference_forward_dft_impl(in, out, sizes[0], 1); } }; template struct reference { - static void forward_dft(const std::vector &sizes, const TypeIn *in, - TypeOut *out) { - const auto elements = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>{}); + static void forward_dft(const std::vector &sizes, const TypeIn *in, TypeOut *out) { + const auto elements = std::accumulate(sizes.begin(), sizes.end(), 1U, std::multiplies<>{}); std::vector> tmp(elements); - for (size_t i = 0; i < elements; i += sizes[1]) { + for (std::size_t i = 0; i < elements; i += sizes[1]) { reference_forward_dft_impl(in + i, tmp.data() + i, sizes[1], 1); } - for (size_t i = 0; i < sizes[1]; i++) { + for (std::size_t i = 0; i < sizes[1]; i++) { reference_forward_dft_impl(tmp.data() + i, out + i, sizes[0], sizes[1]); } } @@ -76,21 +75,20 @@ struct reference { template struct reference { - static void forward_dft(const std::vector &sizes, const TypeIn *in, - TypeOut *out) { - const auto elements = std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>{}); + static void forward_dft(const std::vector &sizes, const TypeIn *in, TypeOut *out) { + const auto elements = std::accumulate(sizes.begin(), sizes.end(), 1U, std::multiplies<>{}); std::vector> tmp1(elements); std::vector> tmp2(elements); - for (size_t i = 0; i < elements; i += sizes[2]) { + for (std::size_t i = 0; i < elements; i += sizes[2]) { reference_forward_dft_impl(in + i, tmp1.data() + i, sizes[2], 1); } - for (size_t j = 0; j < elements; j += sizes[1] * sizes[2]) { - for (size_t i = 0; i < sizes[2]; i++) { + for (std::size_t j = 0; j < elements; j += sizes[1] * sizes[2]) { + for (std::size_t i = 0; i < sizes[2]; i++) { reference_forward_dft_impl(tmp1.data() + i + j, tmp2.data() + i + j, sizes[1], sizes[2]); } } - for (size_t i = 0; i < sizes[1] * sizes[2]; i++) { + for (std::size_t i = 0; i < sizes[1] * sizes[2]; i++) { reference_forward_dft_impl(tmp2.data() + i, out + i, sizes[0], sizes[1] * sizes[2]); } } @@ -115,13 +113,17 @@ struct reference { **/ template void reference_forward_dft(const std::vector &sizes, const TypeIn *in, TypeOut *out) { - switch (sizes.size()) { - case 1: detail::reference::forward_dft(sizes, in, out); break; - case 2: detail::reference::forward_dft(sizes, in, out); break; - case 3: detail::reference::forward_dft(sizes, in, out); break; + std::vector unsigned_sizes(sizes.size()); + std::transform(sizes.begin(), sizes.end(), unsigned_sizes.begin(), + [](std::int64_t size) { return cast_unsigned(size); }); + switch (unsigned_sizes.size()) { + case 1: detail::reference::forward_dft(unsigned_sizes, in, out); break; + case 2: detail::reference::forward_dft(unsigned_sizes, in, out); break; + case 3: detail::reference::forward_dft(unsigned_sizes, in, out); break; default: - throw oneapi::mkl::unimplemented("reference_dft", "forward_dft", - "dft with size " + std::to_string(sizes.size())); + throw oneapi::mkl::unimplemented( + "reference_dft", "forward_dft", + "dft with size " + std::to_string(unsigned_sizes.size())); } } diff --git a/tests/unit_tests/dft/include/test_common.hpp b/tests/unit_tests/dft/include/test_common.hpp index 166d7f6a3..b13723105 100644 --- a/tests/unit_tests/dft/include/test_common.hpp +++ b/tests/unit_tests/dft/include/test_common.hpp @@ -24,6 +24,7 @@ #include #include #include +#include #include #if __has_include() @@ -49,6 +50,13 @@ constexpr bool is_complex() { return complex_info::is_complex; } +inline std::size_t cast_unsigned(std::int64_t i) { + if (i < 0) { + throw std::runtime_error("Unexpected negative value"); + } + return static_cast(i); +} + template bool check_equal(fp x, fp x_ref, double abs_error_mag, double rel_error_mag, std::ostream &out) { using fp_real = typename complex_info::real_type; @@ -65,8 +73,8 @@ bool check_equal(fp x, fp x_ref, double abs_error_mag, double rel_error_mag, std return std::numeric_limits::epsilon(); } }(); - const fp_real abs_bound = abs_error_mag * epsilon; - const fp_real rel_bound = rel_error_mag * epsilon; + const auto abs_bound = static_cast(abs_error_mag) * epsilon; + const auto rel_bound = static_cast(rel_error_mag) * epsilon; const auto aerr = std::abs(x - x_ref); const auto rerr = aerr / std::abs(x_ref); @@ -80,14 +88,20 @@ bool check_equal(fp x, fp x_ref, double abs_error_mag, double rel_error_mag, std } template -bool check_equal_vector(vec1 &&v, vec2 &&v_ref, int n, double abs_error_mag, double rel_error_mag, - std::ostream &out) { +bool check_equal_vector(vec1 &&v, vec2 &&v_ref, std::size_t n, double abs_error_mag, + double rel_error_mag, std::ostream &out) { constexpr int max_print = 20; int count = 0; bool good = true; for (std::size_t i = 0; i < n; ++i) { - if (!check_equal(v[i], v_ref[i], abs_error_mag, rel_error_mag, out)) { + // Allow to convert the unsigned index `i` to a signed one to keep this function generic and allow for `v` and `v_ref` to be a vector, a pointer or a random access iterator. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wsign-conversion" + auto res = v[i]; + auto ref = v_ref[i]; +#pragma clang diagnostic pop + if (!check_equal(res, ref, abs_error_mag, rel_error_mag, out)) { out << " at index i =" << i << "\n"; good = false; ++count; @@ -117,10 +131,10 @@ inline t rand_scalar() { } template -void rand_vector(vec &v, int n) { +void rand_vector(vec &v, std::size_t n) { using fp = typename vec::value_type; v.resize(n); - for (int i = 0; i < n; i++) { + for (std::size_t i = 0; i < n; i++) { v[i] = rand_scalar(); } } @@ -183,9 +197,149 @@ inline std::array get_default_strides(const std::vector +T get_default(const std::vector vec, std::size_t idx, T default_) { + if (idx >= vec.size()) { + return default_; + } + return vec[idx]; +} + +template +std::pair get_default_distances( + const std::vector &sizes, const std::vector &strides_fwd, + const std::vector &strides_bwd) { + std::int64_t size0 = sizes[0]; + std::int64_t size1 = get_default(sizes, 1, 1l); + std::int64_t size2 = get_default(sizes, 2, 1l); + std::int64_t size0_real = + domain == oneapi::mkl::dft::domain::REAL && sizes.size() == 1 ? size0 / 2 + 1 : size0; + std::int64_t size1_real = + domain == oneapi::mkl::dft::domain::REAL && sizes.size() == 2 ? size1 / 2 + 1 : size1; + std::int64_t size2_real = + domain == oneapi::mkl::dft::domain::REAL && sizes.size() == 3 ? size2 / 2 + 1 : size2; + std::int64_t backward_distance = size0_real * size1_real * size2_real; + std::int64_t forward_distance = size0 * size1 * size2; + if (strides_fwd.size() > 1) { + forward_distance = + std::max({ size0 * strides_fwd[1], size1 * get_default(strides_fwd, 2, 0l), + size2 * get_default(strides_fwd, 3, 0l) }); + } + if (strides_bwd.size() > 1) { + backward_distance = + std::max({ size0 * strides_bwd[1], size1 * get_default(strides_bwd, 2, 0l), + size2 * get_default(strides_bwd, 3, 0l) }); + } + if (in_place) { + forward_distance = + std::max(forward_distance, + backward_distance * (domain == oneapi::mkl::dft::domain::REAL ? 2L : 1L)); + } + return { forward_distance, backward_distance }; +} + +//up to 3 dimensions, empty strides = default +template > +std::vector strided_copy( + const T_vec &contiguous, const std::vector &sizes, + const std::vector &strides, std::int64_t batches, std::int64_t distance, + Allocator alloc = {}) { + if (strides.size() == 0) { + return { contiguous.begin(), contiguous.end(), alloc }; + } + using T = typename T_vec::value_type; + std::int64_t size0 = sizes[0]; + std::int64_t size1 = get_default(sizes, 1, 1l); + std::int64_t size2 = get_default(sizes, 2, 1l); + + std::int64_t stride0 = strides[0]; + std::int64_t stride1 = strides[1]; + std::int64_t stride2 = get_default(strides, 2, 0l); + std::int64_t stride3 = get_default(strides, 3, 0l); + std::vector res(cast_unsigned(distance * batches + stride0), alloc); + for (std::int64_t b = 0; b < batches; b++) { + for (std::int64_t i = 0; i < size0; i++) { + for (std::int64_t j = 0; j < size1; j++) { + for (std::int64_t k = 0; k < size2; k++) { + res[cast_unsigned(b * distance + i * stride1 + j * stride2 + k * stride3 + + stride0)] = + contiguous[cast_unsigned(((b * size0 + i) * size1 + j) * size2 + k)]; + } + } + } + } + return res; +} + +//up to 3 dimensions, empty strides = default +template +bool check_equal_strided(const vec1 &v, const vec2 &v_ref, std::vector sizes, + std::vector strides, double abs_error_mag, double rel_error_mag, + std::ostream &out) { + if (strides.size() == 0) { + std::array strides_arr; + if constexpr (ConjugateEvenStrides) { + strides_arr = get_conjugate_even_complex_strides(sizes); + } + else { + strides_arr = get_default_strides(sizes); + } + strides = { &strides_arr[0], &strides_arr[sizes.size() + 1] }; + } + using T = std::decay_t; + std::int64_t size0 = sizes[0]; + std::int64_t size1 = get_default(sizes, 1, 1l); + std::int64_t size2 = get_default(sizes, 2, 1l); + std::int64_t size0_real = ConjugateEvenStrides && sizes.size() == 1 ? size0 / 2 + 1 : size0; + std::int64_t size1_real = ConjugateEvenStrides && sizes.size() == 2 ? size1 / 2 + 1 : size1; + std::int64_t size2_real = ConjugateEvenStrides && sizes.size() == 3 ? size2 / 2 + 1 : size2; + + std::int64_t stride0 = strides[0]; + std::int64_t stride1 = strides[1]; + std::int64_t stride2 = get_default(strides, 2, 0l); + std::int64_t stride3 = get_default(strides, 3, 0l); + + constexpr int max_print = 20; + int count = 0; + bool good = true; + + for (std::int64_t i = 0; i < size0_real; i++) { + for (std::int64_t j = 0; j < size1_real; j++) { + for (std::int64_t k = 0; k < size2_real; k++) { + T res = v[cast_unsigned(i * stride1 + j * stride2 + k * stride3 + stride0)]; + T ref = v_ref[cast_unsigned((i * size1 + j) * size2 + k)]; + if (!check_equal(res, ref, abs_error_mag, rel_error_mag, out)) { + out << " at position " << i << ", " << j << ", " << k << "\n"; + out << " at indices " << i * stride1 + j * stride2 + k * stride3 + stride0 + << ", " << (i * size1 + j) * size2 + k << "\n"; + good = false; + ++count; + if (count > max_print) { + return good; + } + } + } + } + } + return good; +} + struct DFTParams { std::vector sizes; + std::vector strides_fwd; + std::vector strides_bwd; std::int64_t batches; + DFTParams(std::vector sizes, std::int64_t batches) + : sizes(sizes), + strides_fwd({}), + strides_bwd({}), + batches(batches) {} + DFTParams(std::vector sizes, std::vector strides_fwd, + std::vector strides_bwd, std::int64_t batches) + : sizes(sizes), + strides_fwd(strides_fwd), + strides_bwd(strides_bwd), + batches(batches) {} }; class DFTParamsPrint { @@ -193,18 +347,34 @@ class DFTParamsPrint { std::string operator()( testing::TestParamInfo> dev) const { auto [device, params] = dev.param; - auto [sizes, batches] = params; std::string info_name; - assert(sizes.size() > 0); + assert(params.sizes.size() > 0); info_name.append("sizes_"); // intersperse dimensions with "x" - std::for_each(sizes.begin(), sizes.end() - 1, + std::for_each(params.sizes.begin(), params.sizes.end() - 1, [&info_name](auto s) { info_name.append(std::to_string(s)).append("x"); }); - info_name.append(std::to_string(sizes.back())); + info_name.append(std::to_string(params.sizes.back())); + + if (params.strides_fwd.size() != 0) { + info_name.append("_fwd_strides_"); + // intersperse strides with "_" + std::for_each( + params.strides_fwd.begin(), params.strides_fwd.end() - 1, + [&info_name](auto s) { info_name.append(std::to_string(s)).append("_"); }); + info_name.append(std::to_string(params.strides_fwd.back())); + } + if (params.strides_bwd.size() != 0) { + info_name.append("_bwd_strides_"); + // intersperse strides with "_" + std::for_each( + params.strides_bwd.begin(), params.strides_bwd.end() - 1, + [&info_name](auto s) { info_name.append(std::to_string(s)).append("_"); }); + info_name.append(std::to_string(params.strides_bwd.back())); + } - info_name.append("_batches_").append(std::to_string(batches)); + info_name.append("_batches_").append(std::to_string(params.batches)); std::string dev_name = device->get_info(); std::for_each(dev_name.begin(), dev_name.end(), [](auto &c) { diff --git a/tests/unit_tests/dft/source/CMakeLists.txt b/tests/unit_tests/dft/source/CMakeLists.txt index a682d2872..364ad564f 100644 --- a/tests/unit_tests/dft/source/CMakeLists.txt +++ b/tests/unit_tests/dft/source/CMakeLists.txt @@ -17,7 +17,9 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== -set(DFT_SOURCES "compute_tests.cpp" "descriptor_tests.cpp") +set(DFT_SOURCES "compute_tests.cpp" "descriptor_tests.cpp" "workspace_external_tests.cpp") + +include(WarningsUtils) if (BUILD_SHARED_LIBS) add_library(dft_source_rt OBJECT ${DFT_SOURCES}) @@ -34,6 +36,7 @@ if (BUILD_SHARED_LIBS) else () target_link_libraries(dft_source_rt PUBLIC ONEMKL::SYCL::SYCL) endif () + target_link_libraries(dft_source_rt PRIVATE onemkl_warnings) endif () add_library(dft_source_ct OBJECT ${DFT_SOURCES}) @@ -47,11 +50,8 @@ target_include_directories(dft_source_ct ) if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) add_sycl_to_target(TARGET dft_source_ct SOURCES ${DFT_SOURCES}) - target_link_libraries(dft_source_ct PUBLIC onemkl) else () - target_link_libraries(dft_source_ct PUBLIC - onemkl - ONEMKL::SYCL::SYCL - ) + target_link_libraries(dft_source_ct PUBLIC ONEMKL::SYCL::SYCL) endif () +target_link_libraries(dft_source_ct PRIVATE onemkl_warnings) diff --git a/tests/unit_tests/dft/source/compute_tests.cpp b/tests/unit_tests/dft/source/compute_tests.cpp index 69b3eda3a..005f833ef 100644 --- a/tests/unit_tests/dft/source/compute_tests.cpp +++ b/tests/unit_tests/dft/source/compute_tests.cpp @@ -39,23 +39,48 @@ extern std::vector devices; namespace { -class ComputeTests_in_place +class ComputeTests_in_place_COMPLEX : public ::testing::TestWithParam> {}; -class ComputeTests_real_real_in_place +class ComputeTests_real_real_in_place_COMPLEX : public ::testing::TestWithParam> {}; -class ComputeTests_out_of_place +class ComputeTests_out_of_place_COMPLEX : public ::testing::TestWithParam> {}; -class ComputeTests_real_real_out_of_place +class ComputeTests_real_real_out_of_place_COMPLEX : public ::testing::TestWithParam> {}; -#define INSTANTIATE_TEST(PRECISION, DOMAIN, PLACE, LAYOUT, STORAGE) \ - TEST_P(ComputeTests##_##LAYOUT##PLACE, DOMAIN##_##PRECISION##_##PLACE##_##LAYOUT##STORAGE) { \ - auto test = \ - DFT_Test{ \ - std::get<0>(GetParam()), std::get<1>(GetParam()).sizes, \ - std::get<1>(GetParam()).batches \ - }; \ - EXPECT_TRUEORSKIP(test.test_##PLACE##_##LAYOUT##STORAGE()); \ +class ComputeTests_in_place_REAL + : public ::testing::TestWithParam> {}; +class ComputeTests_real_real_in_place_REAL + : public ::testing::TestWithParam> {}; +class ComputeTests_out_of_place_REAL + : public ::testing::TestWithParam> {}; +class ComputeTests_real_real_out_of_place_REAL + : public ::testing::TestWithParam> {}; + +#define INSTANTIATE_TEST(PRECISION, DOMAIN, PLACE, LAYOUT, STORAGE) \ + TEST_P(ComputeTests##_##LAYOUT##PLACE##_##DOMAIN, \ + DOMAIN##_##PRECISION##_##PLACE##_##LAYOUT##STORAGE) { \ + try { \ + auto test = DFT_Test{ \ + std::get<0>(GetParam()), std::get<1>(GetParam()).sizes, \ + std::get<1>(GetParam()).strides_fwd, std::get<1>(GetParam()).strides_bwd, \ + std::get<1>(GetParam()).batches \ + }; \ + EXPECT_TRUEORSKIP(test.test_##PLACE##_##LAYOUT##STORAGE()); \ + } \ + catch (oneapi::mkl::unimplemented & e) { \ + std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; \ + GTEST_SKIP(); \ + } \ + catch (std::exception & e) { \ + std::string msg = e.what(); \ + if (msg.find("FFT_UNIMPLEMENTED") != std::string::npos) { \ + std::cout << "Skipping test because: \"" << msg << "\"" << std::endl; \ + GTEST_SKIP(); \ + } \ + throw; \ + } \ } #define INSTANTIATE_TEST_DIMENSIONS_PRECISION_DOMAIN(PLACE, LAYOUT, STORAGE) \ @@ -73,39 +98,110 @@ class ComputeTests_real_real_out_of_place INSTANTIATE_TEST_DIMENSIONS_PRECISION_DOMAIN_PLACE_LAYOUT(buffer) INSTANTIATE_TEST_DIMENSIONS_PRECISION_DOMAIN_PLACE_LAYOUT(USM) -using shape = std::vector; +using shape = std::vector; using i64 = std::int64_t; -// Parameter format - { shape of transform, number of transforms } +// Parameter format - { shape of transform, number of transforms } or { shape, forward strides, backward strides, number of transforms } +// strides need to be chosen in a way that also makes sense for real transforms std::vector test_params{ - { shape{ 8 }, i64{ 1 } }, { shape{ 9 }, i64{ 2 } }, { shape{ 8 }, i64{ 27 } }, - { shape{ 22 }, i64{ 1 } }, { shape{ 128 }, i64{ 1 } }, - - { shape{ 4, 4 }, i64{ 1 } }, { shape{ 4, 4 }, i64{ 2 } }, { shape{ 4, 3 }, i64{ 27 } }, - { shape{ 7, 8 }, i64{ 1 } }, { shape{ 64, 5 }, i64{ 1 } }, - - { shape{ 2, 2, 2 }, i64{ 1 } }, { shape{ 2, 2, 3 }, i64{ 2 } }, { shape{ 2, 2, 2 }, i64{ 27 } }, - { shape{ 3, 7, 2 }, i64{ 1 } }, { shape{ 8, 8, 9 }, i64{ 1 } }, + { shape{ 8 }, i64{ 1 } }, + { shape{ 9 }, i64{ 2 } }, + { shape{ 8 }, i64{ 27 } }, + { shape{ 22 }, i64{ 1 } }, + { shape{ 128 }, i64{ 1 } }, + + { shape{ 4, 4 }, i64{ 1 } }, + { shape{ 4, 4 }, i64{ 2 } }, + { shape{ 4, 3 }, i64{ 9 } }, + { shape{ 7, 8 }, i64{ 1 } }, + { shape{ 64, 5 }, i64{ 1 } }, + + { shape{ 2, 2, 2 }, i64{ 1 } }, + { shape{ 2, 2, 3 }, i64{ 2 } }, + { shape{ 2, 2, 2 }, i64{ 27 } }, + { shape{ 3, 7, 2 }, i64{ 1 } }, + { shape{ 8, 8, 9 }, i64{ 1 } }, + + { shape{ 4, 3 }, shape{ 2, 3, 1 }, shape{ 2, 3, 1 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 0, 4, 1 }, shape{ 0, 3, 1 }, i64{ 3 } }, + { shape{ 4, 3 }, shape{ 4, 6, 2 }, shape{ 2, 6, 2 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 1, 1, 4 }, shape{ 1, 1, 4 }, i64{ 9 } }, + { shape{ 4, 4 }, shape{ 2, 4, 1 }, shape{ 0, 4, 1 }, i64{ 2 } }, + { shape{ 4, 4 }, shape{ 0, 1, 5 }, shape{ 0, 1, 4 }, i64{ 2 } }, + { shape{ 4, 4 }, shape{ 0, 1, 4 }, shape{ 0, 2, 9 }, i64{ 2 } }, + { shape{ 4, 4 }, shape{ 0, 7, 1 }, shape{ 0, 5, 1 }, i64{ 2 } }, + { shape{ 4, 4 }, shape{ 0, 8, 2 }, shape{ 0, 8, 2 }, i64{ 2 } }, + { shape{ 4, 4 }, shape{ 0, 4, 1 }, shape{ 0, 1, 4 }, i64{ 2 } }, + + { shape{ 4, 4, 4 }, shape{ 2, 1, 4, 16 }, shape{ 4, 1, 4, 16 }, i64{ 2 } }, + { shape{ 4, 4, 4 }, shape{ 4, 17, 4, 1 }, shape{ 4, 23, 5, 1 }, i64{ 2 } }, + { shape{ 4, 4, 4 }, shape{ 0, 32, 8, 2 }, shape{ 0, 32, 8, 2 }, i64{ 2 } }, + { shape{ 4, 4, 4 }, shape{ 2, 4, 1, 16 }, shape{ 1, 4, 16, 1 }, i64{ 2 } }, + { shape{ 4, 4, 4 }, shape{ 0, 1, 32, 8 }, shape{ 0, 1, 32, 8 }, i64{ 2 } }, +}; +std::vector test_params_real_in_place{ + { shape{ 8 }, i64{ 1 } }, + { shape{ 9 }, i64{ 2 } }, + { shape{ 8 }, i64{ 27 } }, + { shape{ 22 }, i64{ 1 } }, + { shape{ 128 }, i64{ 1 } }, + + { shape{ 4, 4 }, i64{ 1 } }, + { shape{ 4, 4 }, i64{ 2 } }, + { shape{ 4, 3 }, i64{ 9 } }, + { shape{ 7, 8 }, i64{ 1 } }, + { shape{ 64, 5 }, i64{ 1 } }, + + { shape{ 2, 2, 2 }, i64{ 1 } }, + { shape{ 2, 2, 3 }, i64{ 2 } }, + { shape{ 2, 2, 2 }, i64{ 27 } }, + { shape{ 3, 7, 2 }, i64{ 1 } }, + { shape{ 8, 8, 9 }, i64{ 1 } }, + + { shape{ 4, 3 }, shape{ 0, 4, 1 }, shape{ 0, 2, 1 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 0, 6, 1 }, shape{ 0, 3, 1 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 0, 8, 2 }, shape{ 0, 4, 2 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 2, 4, 1 }, shape{ 1, 2, 1 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 6, 1, 4 }, shape{ 3, 1, 4 }, i64{ 9 } }, + { shape{ 4, 3 }, shape{ 0, 1, 5 }, shape{ 0, 1, 5 }, i64{ 2 } }, + { shape{ 4, 3 }, shape{ 0, 3, 12 }, shape{ 0, 3, 12 }, i64{ 9 } }, + + { shape{ 4, 4, 4 }, shape{ 4, 1, 4, 16 }, shape{ 2, 1, 4, 16 }, i64{ 2 } }, + { shape{ 4, 4, 4 }, shape{ 0, 48, 12, 2 }, shape{ 0, 24, 6, 2 }, i64{ 2 } }, + { shape{ 4, 4, 4 }, shape{ 0, 1, 48, 8 }, shape{ 0, 1, 24, 8 }, i64{ 2 } }, }; -// not currently implemented apis -std::vector no_tests{}; - -INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_in_place, +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_in_place_COMPLEX, testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(test_params)), DFTParamsPrint{}); - -INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_in_place, - testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(no_tests)), +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_in_place_COMPLEX, + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params)), DFTParamsPrint{}); - -INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_out_of_place, +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_out_of_place_COMPLEX, + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params)), + DFTParamsPrint{}); +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_out_of_place_COMPLEX, testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(test_params)), DFTParamsPrint{}); -INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_out_of_place, - testing::Combine(testing::ValuesIn(devices), testing::ValuesIn(no_tests)), +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_in_place_REAL, + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params_real_in_place)), + DFTParamsPrint{}); +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_in_place_REAL, + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params_real_in_place)), + DFTParamsPrint{}); +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_out_of_place_REAL, + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params)), + DFTParamsPrint{}); +INSTANTIATE_TEST_SUITE_P(ComputeTestSuite, ComputeTests_real_real_out_of_place_REAL, + testing::Combine(testing::ValuesIn(devices), + testing::ValuesIn(test_params)), DFTParamsPrint{}); } // anonymous namespace diff --git a/tests/unit_tests/dft/source/descriptor_tests.cpp b/tests/unit_tests/dft/source/descriptor_tests.cpp index d3dd5ad21..a420eb1e2 100644 --- a/tests/unit_tests/dft/source/descriptor_tests.cpp +++ b/tests/unit_tests/dft/source/descriptor_tests.cpp @@ -42,7 +42,7 @@ constexpr std::int64_t default_1d_lengths = 4; const std::vector default_3d_lengths{ 124, 5, 3 }; template -inline void set_and_get_lengths(sycl::queue& sycl_queue) { +static void set_and_get_lengths() { /* Negative Testing */ { oneapi::mkl::dft::descriptor descriptor{ default_3d_lengths }; @@ -70,8 +70,6 @@ inline void set_and_get_lengths(sycl::queue& sycl_queue) { descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimensions_after_set); EXPECT_EQ(new_lengths, lengths_value); EXPECT_EQ(dimensions, dimensions_after_set); - - commit_descriptor(descriptor, sycl_queue); } /* >= 2D */ @@ -100,8 +98,11 @@ inline void set_and_get_lengths(sycl::queue& sycl_queue) { } } +// Test for deprecated functionality +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" template -inline void set_and_get_strides(sycl::queue& sycl_queue) { +static void set_and_get_io_strides() { oneapi::mkl::dft::descriptor descriptor{ default_3d_lengths }; EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, nullptr), @@ -134,6 +135,8 @@ inline void set_and_get_strides(sycl::queue& sycl_queue) { std::vector input_strides_before_set(strides_size); std::vector input_strides_after_set(strides_size); + std::vector fwd_strides_after_set(strides_size, -1); + std::vector bwd_strides_after_set(strides_size, -1); descriptor.get_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_before_set.data()); @@ -141,7 +144,11 @@ inline void set_and_get_strides(sycl::queue& sycl_queue) { descriptor.set_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_value.data()); descriptor.get_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, input_strides_after_set.data()); + descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_after_set.data()); + descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_after_set.data()); EXPECT_EQ(input_strides_value, input_strides_after_set); + EXPECT_EQ(std::vector(strides_size, 0), fwd_strides_after_set); + EXPECT_EQ(std::vector(strides_size, 0), bwd_strides_after_set); std::vector output_strides_before_set(strides_size); std::vector output_strides_after_set(strides_size); @@ -156,7 +163,68 @@ inline void set_and_get_strides(sycl::queue& sycl_queue) { } template -inline void set_and_get_values(sycl::queue& sycl_queue) { +static void set_and_get_fwd_bwd_strides() { + oneapi::mkl::dft::descriptor descriptor{ default_3d_lengths }; + + EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, nullptr), + oneapi::mkl::invalid_argument); + EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, nullptr), + oneapi::mkl::invalid_argument); + + constexpr std::int64_t strides_size = 4; + const std::int64_t default_stride_d1 = default_3d_lengths[2] * default_3d_lengths[1]; + const std::int64_t default_stride_d2 = default_3d_lengths[2]; + const std::int64_t default_stride_d3 = 1; + + std::vector default_strides_value{ 0, default_stride_d1, default_stride_d2, + default_stride_d3 }; + + std::vector fwd_strides_value; + std::vector bwd_strides_value; + if constexpr (domain == oneapi::mkl::dft::domain::COMPLEX) { + fwd_strides_value = { 50, default_stride_d1 * 2, default_stride_d2 * 2, + default_stride_d3 * 2 }; + bwd_strides_value = { 50, default_stride_d1 * 2, default_stride_d2 * 2, + default_stride_d3 * 2 }; + } + else { + fwd_strides_value = { 0, default_3d_lengths[1] * (default_3d_lengths[2] / 2 + 1) * 2, + (default_3d_lengths[2] / 2 + 1) * 2, 1 }; + bwd_strides_value = { 0, default_3d_lengths[1] * (default_3d_lengths[2] / 2 + 1), + (default_3d_lengths[2] / 2 + 1), 1 }; + } + + std::vector fwd_strides_before_set(strides_size); + std::vector fwd_strides_after_set(strides_size); + std::vector input_strides_after_set(strides_size, -1); + std::vector output_strides_after_set(strides_size, -1); + + descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, + fwd_strides_before_set.data()); + EXPECT_EQ(default_strides_value, fwd_strides_before_set); + descriptor.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_value.data()); + descriptor.get_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides_after_set.data()); + descriptor.get_value(oneapi::mkl::dft::config_param::INPUT_STRIDES, + input_strides_after_set.data()); + descriptor.get_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES, + output_strides_after_set.data()); + EXPECT_EQ(fwd_strides_value, fwd_strides_after_set); + EXPECT_EQ(std::vector(strides_size, 0), input_strides_after_set); + EXPECT_EQ(std::vector(strides_size, 0), output_strides_after_set); + + std::vector bwd_strides_before_set(strides_size); + std::vector bwd_strides_after_set(strides_size); + descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, + bwd_strides_before_set.data()); + EXPECT_EQ(default_strides_value, bwd_strides_before_set); + descriptor.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_value.data()); + descriptor.get_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides_after_set.data()); + EXPECT_EQ(bwd_strides_value, bwd_strides_after_set); +} +#pragma clang diagnostic pop + +template +static void set_and_get_values() { oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; using Precision_Type = @@ -164,7 +232,7 @@ inline void set_and_get_values(sycl::queue& sycl_queue) { double>; { - Precision_Type forward_scale_set_value{ 143.5 }; + auto forward_scale_set_value = Precision_Type(143.5); Precision_Type forward_scale_before_set; Precision_Type forward_scale_after_set; @@ -179,7 +247,7 @@ inline void set_and_get_values(sycl::queue& sycl_queue) { } { - Precision_Type backward_scale_set_value{ 143.5 }; + auto backward_scale_set_value = Precision_Type(143.5); Precision_Type backward_scale_before_set; Precision_Type backward_scale_after_set; @@ -346,10 +414,27 @@ inline void set_and_get_values(sycl::queue& sycl_queue) { descriptor.get_value(oneapi::mkl::dft::config_param::PACKED_FORMAT, &value); EXPECT_EQ(oneapi::mkl::dft::config_value::CCE_FORMAT, value); } + + { + oneapi::mkl::dft::config_value value{ + oneapi::mkl::dft::config_value::COMMITTED + }; // Initialize with invalid value + descriptor.get_value(oneapi::mkl::dft::config_param::WORKSPACE_PLACEMENT, &value); + EXPECT_EQ(oneapi::mkl::dft::config_value::WORKSPACE_AUTOMATIC, value); + + descriptor.set_value(oneapi::mkl::dft::config_param::WORKSPACE_PLACEMENT, + oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL); + + value = oneapi::mkl::dft::config_value::COMMITTED; // Initialize with invalid value + descriptor.get_value(oneapi::mkl::dft::config_param::WORKSPACE_PLACEMENT, &value); + EXPECT_EQ(oneapi::mkl::dft::config_value::WORKSPACE_EXTERNAL, value); + descriptor.set_value(oneapi::mkl::dft::config_param::WORKSPACE_PLACEMENT, + oneapi::mkl::dft::config_value::WORKSPACE_AUTOMATIC); + } } template -inline void get_readonly_values(sycl::queue& sycl_queue) { +static void get_readonly_values() { oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; oneapi::mkl::dft::domain domain_value; @@ -371,14 +456,10 @@ inline void get_readonly_values(sycl::queue& sycl_queue) { oneapi::mkl::dft::config_value commit_status; descriptor.get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status); EXPECT_EQ(commit_status, oneapi::mkl::dft::config_value::UNCOMMITTED); - - commit_descriptor(descriptor, sycl_queue); - descriptor.get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status); - EXPECT_EQ(commit_status, oneapi::mkl::dft::config_value::COMMITTED); } template -inline void set_readonly_values(sycl::queue& sycl_queue) { +static void set_readonly_values() { oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::FORWARD_DOMAIN, @@ -405,8 +486,16 @@ inline void set_readonly_values(sycl::queue& sycl_queue) { EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, oneapi::mkl::dft::config_value::UNCOMMITTED), oneapi::mkl::invalid_argument); +} +template +static void get_commited(sycl::queue& sycl_queue) { + oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; commit_descriptor(descriptor, sycl_queue); + + oneapi::mkl::dft::config_value commit_status; + descriptor.get_value(oneapi::mkl::dft::config_param::COMMIT_STATUS, &commit_status); + EXPECT_EQ(commit_status, oneapi::mkl::dft::config_value::COMMITTED); } template @@ -429,25 +518,28 @@ inline void recommit_values(sycl::queue& sycl_queue) { std::vector argument_groups{ // not changeable // FORWARD_DOMAIN, PRECISION, DIMENSION, COMMIT_STATUS - { std::make_pair(config_param::LENGTHS, std::int64_t{ 10 }), - std::make_pair(config_param::FORWARD_SCALE, PrecisionType{ 1.2 }), - std::make_pair(config_param::BACKWARD_SCALE, PrecisionType{ 3.4 }) }, - { std::make_pair(config_param::NUMBER_OF_TRANSFORMS, std::int64_t{ 5 }), - std::make_pair(config_param::COMPLEX_STORAGE, config_value::COMPLEX_COMPLEX), + { std::make_pair(config_param::COMPLEX_STORAGE, config_value::COMPLEX_COMPLEX), std::make_pair(config_param::REAL_STORAGE, config_value::REAL_REAL), std::make_pair(config_param::CONJUGATE_EVEN_STORAGE, config_value::COMPLEX_COMPLEX) }, { std::make_pair(config_param::PLACEMENT, config_value::NOT_INPLACE), + std::make_pair(config_param::NUMBER_OF_TRANSFORMS, std::int64_t{ 5 }), +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wdeprecated-declarations" std::make_pair(config_param::INPUT_STRIDES, strides.data()), std::make_pair(config_param::OUTPUT_STRIDES, strides.data()), +#pragma clang diagnostic pop std::make_pair(config_param::FWD_DISTANCE, std::int64_t{ 60 }), std::make_pair(config_param::BWD_DISTANCE, std::int64_t{ 70 }) }, { std::make_pair(config_param::WORKSPACE, config_value::ALLOW), std::make_pair(config_param::ORDERING, config_value::ORDERED), std::make_pair(config_param::TRANSPOSE, bool{ false }), - std::make_pair(config_param::PACKED_FORMAT, config_value::CCE_FORMAT) } + std::make_pair(config_param::PACKED_FORMAT, config_value::CCE_FORMAT) }, + { std::make_pair(config_param::LENGTHS, std::int64_t{ 10 }), + std::make_pair(config_param::FORWARD_SCALE, PrecisionType(1.2)), + std::make_pair(config_param::BACKWARD_SCALE, PrecisionType(3.4)) } }; - for (int i = 0; i < argument_groups.size(); i += 1) { + for (std::size_t i = 0; i < argument_groups.size(); i += 1) { for (auto argument : argument_groups[i]) { std::visit([&descriptor, p = argument.first](auto&& a) { descriptor.set_value(p, a); }, argument.second); @@ -455,6 +547,10 @@ inline void recommit_values(sycl::queue& sycl_queue) { try { commit_descriptor(descriptor, sycl_queue); } + catch (oneapi::mkl::unimplemented e) { + std::cout << "unimplemented exception at index " << i << " with error : " << e.what() + << "\ncontinuing...\n"; + } catch (oneapi::mkl::exception& e) { FAIL() << "exception at index " << i << " with error : " << e.what(); } @@ -542,21 +638,73 @@ inline void swap_out_dead_queue(sycl::queue& sycl_queue) { } template -int test(sycl::device* dev) { +static int test_move() { + using config_param = oneapi::mkl::dft::config_param; + // Use forward distance to test an element copied by value (ie. not on heap) + std::int64_t fwdDistanceRef(123); + // Use the DFT dimensions to test heap allocated values. + { + // Move constructor + oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; + descriptor.set_value(config_param::FWD_DISTANCE, fwdDistanceRef); + oneapi::mkl::dft::descriptor descMoved{ std::move(descriptor) }; + std::int64_t fwdDistance(0), dftLength(0); + descMoved.get_value(config_param::FWD_DISTANCE, &fwdDistance); + EXPECT_EQ(fwdDistance, fwdDistanceRef); + descMoved.get_value(config_param::LENGTHS, &dftLength); + EXPECT_EQ(default_1d_lengths, dftLength); + } + { + // Move assignment + oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; + descriptor.set_value(config_param::FWD_DISTANCE, fwdDistanceRef); + oneapi::mkl::dft::descriptor descMoved{ default_1d_lengths }; + descMoved = std::move(descriptor); + std::int64_t fwdDistance(0), dftLength(0); + descMoved.get_value(config_param::FWD_DISTANCE, &fwdDistance); + EXPECT_EQ(fwdDistance, fwdDistanceRef); + descMoved.get_value(config_param::LENGTHS, &dftLength); + EXPECT_EQ(default_1d_lengths, dftLength); + } + + return !::testing::Test::HasFailure(); +} + +template +static int test_getter_setter() { + set_and_get_lengths(); + set_and_get_io_strides(); + set_and_get_fwd_bwd_strides(); + set_and_get_values(); + get_readonly_values(); + set_readonly_values(); + + return !::testing::Test::HasFailure(); +} + +template +int test_commit(sycl::device* dev) { sycl::queue sycl_queue(*dev, exception_handler); if constexpr (precision == oneapi::mkl::dft::precision::DOUBLE) { - if (!sycl_queue.get_device().has(sycl::aspect::fp64)) { + if (!dev->has(sycl::aspect::fp64)) { std::cout << "Device does not support double precision." << std::endl; return test_skipped; } } - set_and_get_lengths(sycl_queue); - set_and_get_strides(sycl_queue); - set_and_get_values(sycl_queue); - get_readonly_values(sycl_queue); - set_readonly_values(sycl_queue); + // test that descriptor is supported + try { + oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; + commit_descriptor(descriptor, sycl_queue); + } + catch (oneapi::mkl::unimplemented& e) { + std::cout << "Skipping because simple commit not supported. Reason: \"" << e.what() + << "\"\n"; + return test_skipped; + } + + get_commited(sycl_queue); recommit_values(sycl_queue); change_queue_causes_wait(sycl_queue); swap_out_dead_queue(sycl_queue); @@ -564,29 +712,71 @@ int test(sycl::device* dev) { return !::testing::Test::HasFailure(); } -class DescriptorTests : public ::testing::TestWithParam {}; +TEST(DescriptorTests, DescriptorMoveRealSingle) { + EXPECT_TRUE((test_move())); +} + +TEST(DescriptorTests, DescriptorMoveRealDouble) { + EXPECT_TRUE((test_move())); +} + +TEST(DescriptorTests, DescriptorMoveComplexSingle) { + EXPECT_TRUE( + (test_move())); +} + +TEST(DescriptorTests, DescriptorMoveComplexDouble) { + EXPECT_TRUE( + (test_move())); +} + +TEST(DescriptorTests, DescriptorTestsRealSingle) { + EXPECT_TRUE(( + test_getter_setter())); +} + +TEST(DescriptorTests, DescriptorTestsRealDouble) { + EXPECT_TRUE(( + test_getter_setter())); +} + +TEST(DescriptorTests, DescriptorTestsComplexSingle) { + EXPECT_TRUE((test_getter_setter())); +} + +TEST(DescriptorTests, DescriptorTestsComplexDouble) { + EXPECT_TRUE((test_getter_setter())); +} + +class DescriptorCommitTests : public ::testing::TestWithParam {}; -TEST_P(DescriptorTests, DescriptorTestsRealSingle) { +TEST_P(DescriptorCommitTests, DescriptorCommitTestsRealSingle) { EXPECT_TRUEORSKIP( - (test(GetParam()))); + (test_commit( + GetParam()))); } -TEST_P(DescriptorTests, DescriptorTestsRealDouble) { +TEST_P(DescriptorCommitTests, DescriptorCommitTestsRealDouble) { EXPECT_TRUEORSKIP( - (test(GetParam()))); + (test_commit( + GetParam()))); } -TEST_P(DescriptorTests, DescriptorTestsComplexSingle) { +TEST_P(DescriptorCommitTests, DescriptorCommitTestsComplexSingle) { EXPECT_TRUEORSKIP( - (test(GetParam()))); + (test_commit( + GetParam()))); } -TEST_P(DescriptorTests, DescriptorTestsComplexDouble) { +TEST_P(DescriptorCommitTests, DescriptorCommitTestsComplexDouble) { EXPECT_TRUEORSKIP( - (test(GetParam()))); + (test_commit( + GetParam()))); } -INSTANTIATE_TEST_SUITE_P(DescriptorTestSuite, DescriptorTests, testing::ValuesIn(devices), - ::DeviceNamePrint()); +INSTANTIATE_TEST_SUITE_P(DescriptorCommitTestSuite, DescriptorCommitTests, + testing::ValuesIn(devices), ::DeviceNamePrint()); } // anonymous namespace diff --git a/tests/unit_tests/dft/source/workspace_external_tests.cpp b/tests/unit_tests/dft/source/workspace_external_tests.cpp new file mode 100644 index 000000000..f96544a90 --- /dev/null +++ b/tests/unit_tests/dft/source/workspace_external_tests.cpp @@ -0,0 +1,403 @@ +/*************************************************************************** +* Copyright (C) Codeplay Software Limited +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* For your convenience, a copy of the License has been included in this +* repository. +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +* +**************************************************************************/ + +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "test_helper.hpp" +#include "test_common.hpp" +#include "parseval_check.hpp" +#include + +extern std::vector devices; + +class WorkspaceExternalTests : public ::testing::TestWithParam {}; + +template +int test_workspace_external_usm_impl(std::size_t dft_size, sycl::device* dev) { + using namespace oneapi::mkl::dft; + using scalar_t = std::conditional_t; + using forward_t = std::conditional_t, scalar_t>; + using backward_t = std::complex; + + sycl::queue sycl_queue(*dev); + if (prec == precision::DOUBLE && !sycl_queue.get_device().has(sycl::aspect::fp64)) { + std::cout << "Device does not support double precision." << std::endl; + return test_skipped; + } + descriptor desc(static_cast(dft_size)); + + desc.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE); + try { + commit_descriptor(desc, sycl_queue); + } + catch (oneapi::mkl::unimplemented&) { + std::cout << "Test configuration not implemented." << std::endl; + return test_skipped; + } + std::int64_t workspace_bytes = -1; + desc.get_value(config_param::WORKSPACE_EXTERNAL_BYTES, &workspace_bytes); + if (workspace_bytes < 0) { + return ::testing::Test::HasFailure(); + } + scalar_t* workspace = sycl::malloc_device( + static_cast(workspace_bytes) / sizeof(scalar_t), sycl_queue); + desc.set_workspace(workspace); + // Generate data + std::vector host_fwd(static_cast(dft_size)); + std::size_t bwd_size = dom == domain::COMPLEX ? dft_size : dft_size / 2 + 1; + std::vector host_bwd(bwd_size); + rand_vector(host_fwd, dft_size); + + // Allocate enough memory that we don't have to worry about the domain. + forward_t* device_fwd = sycl::malloc_device(dft_size, sycl_queue); + backward_t* deviceBwd = sycl::malloc_device(bwd_size, sycl_queue); + sycl_queue.copy(host_fwd.data(), device_fwd, dft_size); + sycl_queue.wait_and_throw(); + + compute_forward(desc, device_fwd, deviceBwd); + sycl_queue.wait_and_throw(); + + sycl_queue.copy(deviceBwd, host_bwd.data(), bwd_size); + sycl_queue.wait_and_throw(); + + // To see external workspaces, larger sizes of DFT may be needed. Using the reference DFT with larger sizes is slow, + // so use Parseval's theorum as a sanity check instead. + bool sanityCheckPasses = parseval_check(dft_size, host_fwd.data(), host_bwd.data()); + + if (sanityCheckPasses) { + sycl_queue.copy(host_fwd.data(), device_fwd, dft_size); + sycl_queue.wait_and_throw(); + compute_backward(desc, deviceBwd, device_fwd); + sycl_queue.wait_and_throw(); + sycl_queue.copy(device_fwd, host_fwd.data(), dft_size); + sycl_queue.wait_and_throw(); + forward_t rescale = + static_cast(1) / static_cast(static_cast(dft_size)); + sanityCheckPasses = parseval_check(dft_size, host_fwd.data(), host_bwd.data(), rescale); + } + + sycl::free(device_fwd, sycl_queue); + sycl::free(deviceBwd, sycl_queue); + sycl::free(workspace, sycl_queue); + return sanityCheckPasses ? !::testing::Test::HasFailure() : ::testing::Test::HasFailure(); +} + +template +int test_workspace_external_buffer_impl(std::size_t dft_size, sycl::device* dev) { + using namespace oneapi::mkl::dft; + using scalar_t = std::conditional_t; + using forward_t = std::conditional_t, scalar_t>; + using backward_t = std::complex; + + sycl::queue sycl_queue(*dev); + if (prec == precision::DOUBLE && !sycl_queue.get_device().has(sycl::aspect::fp64)) { + std::cout << "Device does not support double precision." << std::endl; + return test_skipped; + } + descriptor desc(static_cast(dft_size)); + + desc.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + desc.set_value(config_param::PLACEMENT, config_value::NOT_INPLACE); + try { + commit_descriptor(desc, sycl_queue); + } + catch (oneapi::mkl::unimplemented&) { + std::cout << "Test configuration not implemented." << std::endl; + return test_skipped; + } + std::int64_t workspace_bytes = -1; + desc.get_value(config_param::WORKSPACE_EXTERNAL_BYTES, &workspace_bytes); + if (workspace_bytes < 0) { + return ::testing::Test::HasFailure(); + } + sycl::buffer workspace(static_cast(workspace_bytes) / sizeof(scalar_t)); + desc.set_workspace(workspace); + // Generate data + std::vector host_fwd(static_cast(dft_size)); + std::size_t bwd_size = + dom == domain::COMPLEX ? dft_size : dft_size / 2 + 1; // TODO: Check this! + std::vector host_bwd(bwd_size); + rand_vector(host_fwd, dft_size); + auto host_fwdCpy = host_fwd; // Some backends modify the input data (rocFFT). + + { + sycl::buffer buf_fwd(host_fwd); + sycl::buffer buf_bwd(host_bwd); + compute_forward(desc, buf_fwd, buf_bwd); + } + + // To see external workspaces, larger sizes of DFT may be needed. Using the reference DFT with larger sizes is slow, + // so use Parseval's theorum as a sanity check instead. + bool sanityCheckPasses = parseval_check(dft_size, host_fwdCpy.data(), host_bwd.data()); + + if (sanityCheckPasses) { + auto host_bwdCpy = host_bwd; + { + sycl::buffer buf_fwd(host_fwd); + sycl::buffer buf_bwd(host_bwd); + compute_backward(desc, buf_bwd, buf_fwd); + sycl_queue.wait_and_throw(); + } + forward_t rescale = + static_cast(1) / static_cast(static_cast(dft_size)); + sanityCheckPasses = parseval_check(dft_size, host_fwd.data(), host_bwdCpy.data(), rescale); + } + + return sanityCheckPasses ? !::testing::Test::HasFailure() : ::testing::Test::HasFailure(); +} + +template +void test_workspace_external_usm(sycl::device* dev) { + EXPECT_TRUEORSKIP((test_workspace_external_usm_impl(2, dev))); + EXPECT_TRUEORSKIP((test_workspace_external_usm_impl(1024 * 3 * 5 * 7 * 16, dev))); +} + +template +void test_workspace_external_buffer(sycl::device* dev) { + EXPECT_TRUEORSKIP((test_workspace_external_buffer_impl(2, dev))); + EXPECT_TRUEORSKIP((test_workspace_external_buffer_impl(1024 * 3 * 5 * 7 * 16, dev))); +} + +TEST_P(WorkspaceExternalTests, TestWorkspaceExternalSingleUsm) { + using precision = oneapi::mkl::dft::precision; + using domain = oneapi::mkl::dft::domain; + test_workspace_external_usm(GetParam()); + test_workspace_external_usm(GetParam()); +} + +TEST_P(WorkspaceExternalTests, TestWorkspaceExternalDoubleUsm) { + using precision = oneapi::mkl::dft::precision; + using domain = oneapi::mkl::dft::domain; + test_workspace_external_usm(GetParam()); + test_workspace_external_usm(GetParam()); +} + +TEST_P(WorkspaceExternalTests, TestWorkspaceExternalSingleBuffer) { + using precision = oneapi::mkl::dft::precision; + using domain = oneapi::mkl::dft::domain; + test_workspace_external_buffer(GetParam()); + test_workspace_external_buffer(GetParam()); +} + +TEST_P(WorkspaceExternalTests, TestWorkspaceExternalDoubleBuffer) { + using precision = oneapi::mkl::dft::precision; + using domain = oneapi::mkl::dft::domain; + test_workspace_external_buffer(GetParam()); + test_workspace_external_buffer(GetParam()); +} + +/// A test where set_workspace is called when an external workspace is not set. +TEST_P(WorkspaceExternalTests, SetWorkspaceOnWorkspaceAutomatic) { + using namespace oneapi::mkl::dft; + sycl::queue sycl_queue(*GetParam()); + const int dft_len = 1024 * 3 * 5 * 7 * 16; // A size likely to require an external workspace. + float* fft_data_usm = sycl::malloc_device(dft_len * 2, sycl_queue); + sycl::buffer fft_data_buf(dft_len * 2); + descriptor desc_usm(dft_len), desc_buf(dft_len); + try { + // WORKSPACE_EXTERNAL is NOT set. + commit_descriptor(desc_usm, sycl_queue); + commit_descriptor(desc_buf, sycl_queue); + } + catch (oneapi::mkl::unimplemented&) { + // The DFT size may not be supported. Use a size that is likely to be supported, even if + // that means no external workspace is actually used. + descriptor desc_usm2(2), desc_buf2(2); + desc_usm = std::move(desc_usm2); + desc_buf = std::move(desc_buf2); + commit_descriptor(desc_usm, sycl_queue); + commit_descriptor(desc_buf, sycl_queue); + } + std::int64_t workspace_bytes = 0; + desc_usm.get_value(config_param::WORKSPACE_EXTERNAL_BYTES, &workspace_bytes); + + // No workspace set yet: all of the following should work. + compute_forward(desc_usm, fft_data_usm); + compute_forward(desc_buf, fft_data_buf); + compute_backward(desc_usm, fft_data_usm); + compute_backward(desc_buf, fft_data_buf); + compute_forward(desc_usm, fft_data_buf); + compute_forward(desc_buf, fft_data_usm); + compute_backward(desc_usm, fft_data_buf); + compute_backward(desc_buf, fft_data_usm); + sycl_queue.wait_and_throw(); + + // Set workspace + float* usm_workspace = sycl::malloc_device( + static_cast(workspace_bytes) / sizeof(float), sycl_queue); + sycl::buffer bufferWorkspace(static_cast(workspace_bytes) / sizeof(float)); + desc_usm.set_workspace(usm_workspace); + desc_buf.set_workspace(bufferWorkspace); + + // Should work: + compute_forward(desc_usm, fft_data_usm); + sycl_queue.wait_and_throw(); + compute_forward(desc_buf, fft_data_buf); + sycl_queue.wait_and_throw(); + compute_backward(desc_usm, fft_data_usm); + sycl_queue.wait_and_throw(); + compute_backward(desc_buf, fft_data_buf); + sycl_queue.wait_and_throw(); + + // Should not work: + EXPECT_THROW(compute_forward(desc_usm, fft_data_buf), oneapi::mkl::invalid_argument); + EXPECT_THROW(compute_forward(desc_buf, fft_data_usm), oneapi::mkl::invalid_argument); + EXPECT_THROW(compute_backward(desc_usm, fft_data_buf), oneapi::mkl::invalid_argument); + EXPECT_THROW(compute_backward(desc_buf, fft_data_usm), oneapi::mkl::invalid_argument); + sycl_queue.wait_and_throw(); + + // Free any allocations: + sycl::free(usm_workspace, sycl_queue); + sycl::free(fft_data_usm, sycl_queue); +} + +/// Test that the implementation throws as expected. +TEST_P(WorkspaceExternalTests, ThrowOnBadCalls) { + using namespace oneapi::mkl::dft; + sycl::queue sycl_queue(*GetParam()); + const int dft_len = 1024 * 3 * 5 * 7 * 16; // A size likely to require an external workspace. + float* fft_data_usm = sycl::malloc_device(dft_len * 2, sycl_queue); + sycl::buffer fft_data_buf(dft_len * 2); + descriptor desc_usm(dft_len), desc_buf(dft_len); + desc_usm.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + desc_buf.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + // We expect the following to throw because the decriptor has not been committed. + std::int64_t workspace_bytes = -10; + float* usm_workspace = nullptr; + EXPECT_THROW(desc_usm.get_value(config_param::WORKSPACE_EXTERNAL_BYTES, &workspace_bytes), + oneapi::mkl::invalid_argument); + EXPECT_THROW(desc_usm.set_workspace(usm_workspace), oneapi::mkl::uninitialized); + try { + commit_descriptor(desc_usm, sycl_queue); + commit_descriptor(desc_buf, sycl_queue); + } + catch (oneapi::mkl::unimplemented&) { + // DFT size may not be supported. Use a DFT size that probably will be, even if it + // won't actually use an external workspace internally. + descriptor desc_usm2(2), desc_buf2(2); + desc_usm = std::move(desc_usm2); + desc_buf = std::move(desc_buf2); + desc_usm.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + desc_buf.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + commit_descriptor(desc_usm, sycl_queue); + commit_descriptor(desc_buf, sycl_queue); + } + + desc_usm.get_value(config_param::WORKSPACE_EXTERNAL_BYTES, &workspace_bytes); + EXPECT_GE(workspace_bytes, 0); + + // We haven't set a workspace, so the following should fail; + EXPECT_THROW(compute_forward(desc_usm, fft_data_usm), oneapi::mkl::invalid_argument); + sycl_queue.wait_and_throw(); + EXPECT_THROW(compute_forward(desc_usm, fft_data_buf), oneapi::mkl::invalid_argument); + sycl_queue.wait_and_throw(); + + if (workspace_bytes > 0) { + EXPECT_THROW(desc_usm.set_workspace(nullptr), oneapi::mkl::invalid_argument); + sycl::buffer undersize_workspace( + static_cast(workspace_bytes) / sizeof(float) - 1); + EXPECT_THROW(desc_buf.set_workspace(undersize_workspace), oneapi::mkl::invalid_argument); + } + + usm_workspace = sycl::malloc_device( + static_cast(workspace_bytes) / sizeof(float), sycl_queue); + sycl::buffer bufferWorkspace(static_cast(workspace_bytes) / sizeof(float)); + + desc_usm.set_workspace(usm_workspace); + desc_buf.set_workspace(bufferWorkspace); + + // Should work: + compute_forward(desc_usm, fft_data_usm); + sycl_queue.wait_and_throw(); + compute_forward(desc_buf, fft_data_buf); + sycl_queue.wait_and_throw(); + compute_backward(desc_usm, fft_data_usm); + sycl_queue.wait_and_throw(); + compute_backward(desc_buf, fft_data_buf); + sycl_queue.wait_and_throw(); + + // Should not work: + EXPECT_THROW(compute_forward(desc_usm, fft_data_buf), oneapi::mkl::invalid_argument); + EXPECT_THROW(compute_forward(desc_buf, fft_data_usm), oneapi::mkl::invalid_argument); + EXPECT_THROW(compute_backward(desc_usm, fft_data_buf), oneapi::mkl::invalid_argument); + EXPECT_THROW(compute_backward(desc_buf, fft_data_usm), oneapi::mkl::invalid_argument); + sycl_queue.wait_and_throw(); + + // Free any allocations: + sycl::free(usm_workspace, sycl_queue); + sycl::free(fft_data_usm, sycl_queue); +} + +TEST_P(WorkspaceExternalTests, RecommitBehaviour) { + using namespace oneapi::mkl::dft; + sycl::queue sycl_queue(*GetParam()); + const int dft_len = 1024 * 3 * 5 * 7 * 16; // A size likely to require an external workspace. + float* fft_data_usm = sycl::malloc_device(dft_len * 2, sycl_queue); + descriptor desc_usm(dft_len); + try { + // WORKSPACE_EXTERNAL is NOT set. + commit_descriptor(desc_usm, sycl_queue); + } + catch (oneapi::mkl::unimplemented&) { + // DFT size may not be supported. Use a DFT size that probably will be, even if it + // won't actually use an external workspace internally. + descriptor desc_usm2(2); + desc_usm = std::move(desc_usm2); + commit_descriptor(desc_usm, sycl_queue); + } + std::int64_t workspace_bytes = 0; + desc_usm.get_value(config_param::WORKSPACE_EXTERNAL_BYTES, &workspace_bytes); + float* usm_workspace = sycl::malloc_device( + static_cast(workspace_bytes) / sizeof(float), sycl_queue); + + // Should work with workspace automatic + compute_forward(desc_usm, fft_data_usm); + sycl_queue.wait_and_throw(); + + desc_usm.set_value(config_param::WORKSPACE_PLACEMENT, config_value::WORKSPACE_EXTERNAL); + commit_descriptor(desc_usm, sycl_queue); + + // No workspace, expect throw + EXPECT_THROW(compute_forward(desc_usm, fft_data_usm), oneapi::mkl::invalid_argument); + + desc_usm.set_workspace(usm_workspace); + + compute_forward(desc_usm, fft_data_usm); + sycl_queue.wait_and_throw(); + + // Recommitting should require workspace to be set again. + commit_descriptor(desc_usm, sycl_queue); + EXPECT_THROW(compute_forward(desc_usm, fft_data_usm), oneapi::mkl::invalid_argument); + sycl_queue.wait_and_throw(); + + // Free any allocations: + sycl::free(usm_workspace, sycl_queue); + sycl::free(fft_data_usm, sycl_queue); +} + +INSTANTIATE_TEST_SUITE_P(WorkspaceExternalTestSuite, WorkspaceExternalTests, + testing::ValuesIn(devices), ::DeviceNamePrint()); diff --git a/tests/unit_tests/include/test_helper.hpp b/tests/unit_tests/include/test_helper.hpp index eeb921274..7e0024195 100644 --- a/tests/unit_tests/include/test_helper.hpp +++ b/tests/unit_tests/include/test_helper.hpp @@ -44,6 +44,9 @@ #define test_passed 1 #define test_skipped 2 +// Note GTEST_SKIP may not print the associated message when using ctest. +// However, running a test binary with the flag `--terse-output` will print them. + #define EXPECT_TRUEORSKIP(a) \ do { \ int res = a; \ @@ -53,6 +56,23 @@ EXPECT_EQ(res, test_passed); \ } while (0); +// GTEST_SKIP stops the execution of the program. +// This macro lets a test use multiple EXPECT_TRUE_OR_FUTURE_SKIP and mark a test as skipped only once at the end. +#define EXPECT_TRUE_OR_FUTURE_SKIP(a, num_passed, num_skipped) \ + do { \ + int res = a; \ + if (res == test_skipped) \ + ++num_skipped; \ + else { \ + ++num_passed; \ + EXPECT_EQ(res, test_passed); \ + } \ + } while (0); + +#define CHECK_DOUBLE_ON_DEVICE(d) \ + if (d->get_info().size() == 0) \ + GTEST_SKIP() << "Double precision is not supported on the device" + #if defined(ENABLE_MKLCPU_BACKEND) || defined(ENABLE_NETLIB_BACKEND) #ifdef ENABLE_MKLCPU_BACKEND #define TEST_RUN_INTELCPU_SELECT_NO_ARGS(q, func) \ @@ -119,6 +139,43 @@ #define TEST_RUN_AMDGPU_ROCSOLVER_SELECT(q, func, ...) #endif +#ifdef ENABLE_PORTBLAS_BACKEND +#define TEST_RUN_PORTBLAS_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_PORTBLAS_SELECT(q, func, ...) +#endif + +#ifdef ENABLE_CUFFT_BACKEND +#define TEST_RUN_NVIDIAGPU_CUFFT_SELECT_NO_ARGS(q, func) \ + func(oneapi::mkl::backend_selector{ q }) +#define TEST_RUN_NVIDIAGPU_CUFFT_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_NVIDIAGPU_CUFFT_SELECT_NO_ARGS(q, func) +#define TEST_RUN_NVIDIAGPU_CUFFT_SELECT(q, func, ...) +#endif + +#ifdef ENABLE_ROCFFT_BACKEND +#define TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func) \ + func(oneapi::mkl::backend_selector{ q }) +#define TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func) +#define TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, ...) +#endif + +#ifdef ENABLE_PORTFFT_BACKEND +#define TEST_RUN_PORTFFT_SELECT_NO_ARGS(q, func) \ + func(oneapi::mkl::backend_selector{ q }) +#define TEST_RUN_PORTFFT_SELECT(q, func, ...) \ + func(oneapi::mkl::backend_selector{ q }, __VA_ARGS__) +#else +#define TEST_RUN_PORTFFT_SELECT_NO_ARGS(q, func) +#define TEST_RUN_PORTFFT_SELECT(q, func, ...) +#endif + #ifndef __HIPSYCL__ #define CHECK_HOST_OR_CPU(q) q.get_device().is_cpu() #else @@ -136,7 +193,14 @@ if (vendor_id == INTEL_ID) { \ TEST_RUN_INTELGPU_SELECT_NO_ARGS(q, func); \ } \ + else if (vendor_id == NVIDIA_ID) { \ + TEST_RUN_NVIDIAGPU_CUFFT_SELECT_NO_ARGS(q, func); \ + } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCFFT_SELECT_NO_ARGS(q, func); \ + } \ } \ + TEST_RUN_PORTFFT_SELECT_NO_ARGS(q, func); \ } while (0); #define TEST_RUN_CT_SELECT(q, func, ...) \ @@ -157,6 +221,64 @@ TEST_RUN_AMDGPU_ROCBLAS_SELECT(q, func, __VA_ARGS__); \ TEST_RUN_AMDGPU_ROCRAND_SELECT(q, func, __VA_ARGS__); \ TEST_RUN_AMDGPU_ROCSOLVER_SELECT(q, func, __VA_ARGS__); \ + TEST_RUN_AMDGPU_ROCFFT_SELECT(q, func, __VA_ARGS__); \ + } \ + } \ + TEST_RUN_PORTBLAS_SELECT(q, func, __VA_ARGS__); \ + TEST_RUN_PORTFFT_SELECT(q, func, __VA_ARGS__); \ + } while (0); + +#define TEST_RUN_BLAS_CT_SELECT(q, func, ...) \ + do { \ + if (CHECK_HOST_OR_CPU(q)) \ + TEST_RUN_INTELCPU_SELECT(q, func, __VA_ARGS__); \ + else if (q.get_device().is_gpu()) { \ + unsigned int vendor_id = static_cast( \ + q.get_device().get_info()); \ + if (vendor_id == INTEL_ID) \ + TEST_RUN_INTELGPU_SELECT(q, func, __VA_ARGS__); \ + else if (vendor_id == NVIDIA_ID) { \ + TEST_RUN_NVIDIAGPU_CUBLAS_SELECT(q, func, __VA_ARGS__); \ + } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCBLAS_SELECT(q, func, __VA_ARGS__); \ + } \ + } \ + TEST_RUN_PORTBLAS_SELECT(q, func, __VA_ARGS__); \ + } while (0); + +#define TEST_RUN_RNG_CT_SELECT(q, func, ...) \ + do { \ + if (CHECK_HOST_OR_CPU(q)) \ + TEST_RUN_INTELCPU_SELECT(q, func, __VA_ARGS__); \ + else if (q.get_device().is_gpu()) { \ + unsigned int vendor_id = static_cast( \ + q.get_device().get_info()); \ + if (vendor_id == INTEL_ID) \ + TEST_RUN_INTELGPU_SELECT(q, func, __VA_ARGS__); \ + else if (vendor_id == NVIDIA_ID) { \ + TEST_RUN_NVIDIAGPU_CURAND_SELECT(q, func, __VA_ARGS__); \ + } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCRAND_SELECT(q, func, __VA_ARGS__); \ + } \ + } \ + } while (0); + +#define TEST_RUN_LAPACK_CT_SELECT(q, func, ...) \ + do { \ + if (CHECK_HOST_OR_CPU(q)) \ + TEST_RUN_INTELCPU_SELECT(q, func, __VA_ARGS__); \ + else if (q.get_device().is_gpu()) { \ + unsigned int vendor_id = static_cast( \ + q.get_device().get_info()); \ + if (vendor_id == INTEL_ID) \ + TEST_RUN_INTELGPU_SELECT(q, func, __VA_ARGS__); \ + else if (vendor_id == NVIDIA_ID) { \ + TEST_RUN_NVIDIAGPU_CUSOLVER_SELECT(q, func, __VA_ARGS__); \ + } \ + else if (vendor_id == AMD_ID) { \ + TEST_RUN_AMDGPU_ROCSOLVER_SELECT(q, func, __VA_ARGS__); \ } \ } \ } while (0); @@ -181,9 +303,8 @@ class LayoutDeviceNamePrint { public: std::string operator()( testing::TestParamInfo> dev) const { - std::string layout_name = std::get<1>(dev.param) == oneapi::mkl::layout::column_major - ? "Column_Major" - : "Row_Major"; + std::string layout_name = + std::get<1>(dev.param) == oneapi::mkl::layout::col_major ? "Column_Major" : "Row_Major"; std::string dev_name = std::get<0>(dev.param)->get_info(); for (std::string::size_type i = 0; i < dev_name.size(); ++i) { if (!isalnum(dev_name[i])) @@ -217,6 +338,7 @@ static inline void aligned_free(void *p) { /* Support for Unified Shared Memory allocations for different backends */ static inline void *malloc_shared(size_t align, size_t size, sycl::device dev, sycl::context ctx) { + (void)align; #ifdef _WIN64 return sycl::malloc_shared(size, dev, ctx); #else @@ -229,10 +351,28 @@ static inline void *malloc_shared(size_t align, size_t size, sycl::device dev, s #endif } +static inline void *malloc_device(size_t align, size_t size, sycl::device dev, sycl::context ctx) { + (void)align; +#ifdef _WIN64 + return sycl::malloc_device(size, dev, ctx); +#else +#if defined(ENABLE_CUBLAS_BACKEND) || defined(ENABLE_ROCBLAS_BACKEND) + return sycl::aligned_alloc_device(align, size, dev, ctx); +#endif +#if !defined(ENABLE_CUBLAS_BACKEND) && !defined(ENABLE_ROCBLAS_BACKEND) + return sycl::malloc_device(size, dev, ctx); +#endif +#endif +} + static inline void free_shared(void *p, sycl::context ctx) { sycl::free(p, ctx); } +static inline void free_usm(void *p, sycl::context ctx) { + sycl::free(p, ctx); +} + } // namespace mkl } // namespace oneapi diff --git a/tests/unit_tests/lapack/include/lapack_gtest_suite.hpp b/tests/unit_tests/lapack/include/lapack_gtest_suite.hpp index a04543410..41e349d7c 100644 --- a/tests/unit_tests/lapack/include/lapack_gtest_suite.hpp +++ b/tests/unit_tests/lapack/include/lapack_gtest_suite.hpp @@ -86,6 +86,7 @@ using ComplexDoublePrecisionUsm = std::complex; EXPECT_TRUE(accuracy_controller.run(::accuracy, *GetParam())); \ } \ TEST_P(SUITE##AccuracyUsm, RealDoublePrecision) { \ + CHECK_DOUBLE_ON_DEVICE(GetParam()); \ test_log::padding = "[ ] "; \ EXPECT_TRUE(accuracy_controller.run(::accuracy, *GetParam())); \ } @@ -96,6 +97,7 @@ using ComplexDoublePrecisionUsm = std::complex; EXPECT_TRUE(accuracy_controller.run(::accuracy, *GetParam())); \ } \ TEST_P(SUITE##AccuracyUsm, ComplexDoublePrecision) { \ + CHECK_DOUBLE_ON_DEVICE(GetParam()); \ test_log::padding = "[ ] "; \ EXPECT_TRUE(accuracy_controller.run(::accuracy, *GetParam())); \ } @@ -106,6 +108,7 @@ using ComplexDoublePrecisionUsm = std::complex; EXPECT_TRUE(accuracy_controller.run(::accuracy, *GetParam())); \ } \ TEST_P(SUITE##AccuracyBuffer, RealDoublePrecision) { \ + CHECK_DOUBLE_ON_DEVICE(GetParam()); \ test_log::padding = "[ ] "; \ EXPECT_TRUE(accuracy_controller.run(::accuracy, *GetParam())); \ } @@ -117,6 +120,7 @@ using ComplexDoublePrecisionUsm = std::complex; accuracy_controller.run(::accuracy, *GetParam())); \ } \ TEST_P(SUITE##AccuracyBuffer, ComplexDoublePrecision) { \ + CHECK_DOUBLE_ON_DEVICE(GetParam()); \ test_log::padding = "[ ] "; \ EXPECT_TRUE( \ accuracy_controller.run(::accuracy, *GetParam())); \ @@ -145,6 +149,7 @@ using ComplexDoublePrecisionUsm = std::complex; dependency_controller.run(::usm_dependency, *GetParam())); \ } \ TEST_P(SUITE##DependencyUsm, RealDoublePrecision) { \ + CHECK_DOUBLE_ON_DEVICE(GetParam()); \ test_log::padding = "[ ] "; \ EXPECT_TRUE( \ dependency_controller.run(::usm_dependency, *GetParam())); \ @@ -157,6 +162,7 @@ using ComplexDoublePrecisionUsm = std::complex; dependency_controller.run(::usm_dependency, *GetParam())); \ } \ TEST_P(SUITE##DependencyUsm, ComplexDoublePrecision) { \ + CHECK_DOUBLE_ON_DEVICE(GetParam()); \ test_log::padding = "[ ] "; \ EXPECT_TRUE( \ dependency_controller.run(::usm_dependency, *GetParam())); \ diff --git a/tests/unit_tests/lapack/source/gebrd.cpp b/tests/unit_tests/lapack/source/gebrd.cpp index 71ee58af0..66eb0b231 100644 --- a/tests/unit_tests/lapack/source/gebrd.cpp +++ b/tests/unit_tests/lapack/source/gebrd.cpp @@ -73,8 +73,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::gebrd_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -85,8 +85,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::gebrd(queue, m, n, A_dev, lda, d_dev, e_dev, tauq_dev, taup_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::gebrd, m, n, A_dev, lda, d_dev, e_dev, - tauq_dev, taup_dev, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::gebrd, m, n, A_dev, lda, d_dev, e_dev, + tauq_dev, taup_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -146,8 +146,8 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, oneapi::mkl::lapack::gebrd_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::gebrd_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -162,9 +162,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::gebrd, m, n, A_dev, lda, d_dev, - e_dev, tauq_dev, taup_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::gebrd, m, n, A_dev, lda, + d_dev, e_dev, tauq_dev, taup_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/geqrf.cpp b/tests/unit_tests/lapack/source/geqrf.cpp index 8d463292d..27577e972 100644 --- a/tests/unit_tests/lapack/source/geqrf.cpp +++ b/tests/unit_tests/lapack/source/geqrf.cpp @@ -65,8 +65,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::geqrf_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -77,8 +77,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::geqrf(queue, m, n, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::geqrf, m, n, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::geqrf, m, n, A_dev, lda, tau_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -122,8 +122,8 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, oneapi::mkl::lapack::geqrf_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::geqrf_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -138,9 +138,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::geqrf, m, n, A_dev, lda, - tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::geqrf, m, n, A_dev, lda, + tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/geqrf_batch_group.cpp b/tests/unit_tests/lapack/source/geqrf_batch_group.cpp index 0cc5bbff2..416466028 100644 --- a/tests/unit_tests/lapack/source/geqrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/geqrf_batch_group.cpp @@ -81,8 +81,8 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> tau_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector tau_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** tau_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -99,7 +99,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT( + TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif @@ -120,13 +120,14 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue.wait_and_throw(); #ifdef CALL_RT_API - oneapi::mkl::lapack::geqrf_batch(queue, m_vec.data(), n_vec.data(), A_dev_ptrs.data(), - lda_vec.data(), tau_dev_ptrs.data(), group_count, + oneapi::mkl::lapack::geqrf_batch(queue, m_vec.data(), n_vec.data(), A_dev_ptrs, + lda_vec.data(), tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::geqrf_batch, m_vec.data(), n_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), tau_dev_ptrs.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::geqrf_batch, m_vec.data(), + n_vec.data(), A_dev_ptrs, lda_vec.data(), tau_dev_ptrs, + group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -137,6 +138,15 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, tau_dev_ptrs[global_id], tau_iter->data(), tau_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (tau_dev_ptrs) { + sycl::free(tau_dev_ptrs, queue); + } } bool result = true; @@ -209,8 +219,8 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> tau_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector tau_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** tau_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -227,7 +237,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT( + TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif @@ -251,19 +261,28 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::geqrf_batch( - queue, m_vec.data(), n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - tau_dev_ptrs.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + queue, m_vec.data(), n_vec.data(), A_dev_ptrs, lda_vec.data(), tau_dev_ptrs, + group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::geqrf_batch, m_vec.data(), - n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), tau_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::geqrf_batch, + m_vec.data(), n_vec.data(), A_dev_ptrs, lda_vec.data(), + tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (tau_dev_ptrs) { + sycl::free(tau_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp b/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp index b1846bd56..16ceef63a 100644 --- a/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/geqrf_batch_stride.cpp @@ -65,9 +65,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ queue, m, n, lda, stride_a, stride_tau, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, - m, n, lda, stride_a, stride_tau, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m, n, + lda, stride_a, stride_tau, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -78,8 +78,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ oneapi::mkl::lapack::geqrf_batch(queue, m, n, A_dev, lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::geqrf_batch, m, n, A_dev, lda, stride_a, - tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::geqrf_batch, m, n, A_dev, lda, + stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -137,9 +138,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, queue, m, n, lda, stride_a, stride_tau, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, - m, n, lda, stride_a, stride_tau, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::geqrf_batch_scratchpad_size, m, n, + lda, stride_a, stride_tau, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -154,9 +155,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::geqrf_batch, m, n, A_dev, lda, - stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::geqrf_batch, m, n, A_dev, + lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/gerqf.cpp b/tests/unit_tests/lapack/source/gerqf.cpp index 19bf339c2..dac6d79aa 100644 --- a/tests/unit_tests/lapack/source/gerqf.cpp +++ b/tests/unit_tests/lapack/source/gerqf.cpp @@ -65,8 +65,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::gerqf_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -77,8 +77,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::gerqf(queue, m, n, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::gerqf, m, n, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::gerqf, m, n, A_dev, lda, tau_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -122,8 +122,8 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, oneapi::mkl::lapack::gerqf_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::gerqf_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -138,9 +138,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::gerqf, m, n, A_dev, lda, - tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::gerqf, m, n, A_dev, lda, + tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/gesvd.cpp b/tests/unit_tests/lapack/source/gesvd.cpp index 6a98b7e23..1e143315b 100644 --- a/tests/unit_tests/lapack/source/gesvd.cpp +++ b/tests/unit_tests/lapack/source/gesvd.cpp @@ -75,8 +75,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::mkl::jo queue, jobu, jobvt, m, n, lda, ldu, ldvt); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, - jobu, jobvt, m, n, lda, ldu, ldvt); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, + jobu, jobvt, m, n, lda, ldu, ldvt); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -87,8 +88,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::mkl::jo oneapi::mkl::lapack::gesvd(queue, jobu, jobvt, m, n, A_dev, lda, s_dev, U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::gesvd, jobu, jobvt, m, n, A_dev, lda, s_dev, - U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::gesvd, jobu, jobvt, m, n, A_dev, lda, + s_dev, U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -196,8 +197,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::m queue, jobu, jobvt, m, n, lda, ldu, ldvt); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, - jobu, jobvt, m, n, lda, ldu, ldvt); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::gesvd_scratchpad_size, + jobu, jobvt, m, n, lda, ldu, ldvt); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -212,9 +214,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::jobsvd jobu, oneapi::m scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::gesvd, jobu, jobvt, m, n, A_dev, - lda, s_dev, U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::gesvd, jobu, jobvt, m, n, + A_dev, lda, s_dev, U_dev, ldu, Vt_dev, ldvt, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/getrf.cpp b/tests/unit_tests/lapack/source/getrf.cpp index cd5c48996..4537ef665 100644 --- a/tests/unit_tests/lapack/source/getrf.cpp +++ b/tests/unit_tests/lapack/source/getrf.cpp @@ -68,8 +68,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::getrf_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -80,8 +80,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, uint64 oneapi::mkl::lapack::getrf(queue, m, n, A_dev, lda, ipiv_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getrf, m, n, A_dev, lda, ipiv_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getrf, m, n, A_dev, lda, ipiv_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -125,8 +125,8 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, oneapi::mkl::lapack::getrf_scratchpad_size(queue, m, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, - m, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrf_scratchpad_size, m, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -141,9 +141,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrf, m, n, A_dev, lda, - ipiv_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrf, m, n, A_dev, lda, + ipiv_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/getrf_batch_group.cpp b/tests/unit_tests/lapack/source/getrf_batch_group.cpp index f50803d7d..12e651746 100644 --- a/tests/unit_tests/lapack/source/getrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getrf_batch_group.cpp @@ -82,8 +82,8 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> ipiv_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector ipiv_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + int64_t** ipiv_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -103,7 +103,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT( + TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif @@ -124,13 +124,14 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue.wait_and_throw(); #ifdef CALL_RT_API - oneapi::mkl::lapack::getrf_batch(queue, m_vec.data(), n_vec.data(), A_dev_ptrs.data(), - lda_vec.data(), ipiv_dev_ptrs.data(), group_count, + oneapi::mkl::lapack::getrf_batch(queue, m_vec.data(), n_vec.data(), A_dev_ptrs, + lda_vec.data(), ipiv_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getrf_batch, m_vec.data(), n_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), ipiv_dev_ptrs.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getrf_batch, m_vec.data(), + n_vec.data(), A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, + group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -142,6 +143,15 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { ipiv_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (ipiv_dev_ptrs) { + sycl::free(ipiv_dev_ptrs, queue); + } } bool result = true; @@ -215,8 +225,8 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> ipiv_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector ipiv_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + int64_t** ipiv_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -236,7 +246,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT( + TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif @@ -260,19 +270,28 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::getrf_batch( - queue, m_vec.data(), n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - ipiv_dev_ptrs.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + queue, m_vec.data(), n_vec.data(), A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, + group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrf_batch, m_vec.data(), - n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), ipiv_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT( + queue, func_event = oneapi::mkl::lapack::getrf_batch, m_vec.data(), n_vec.data(), + A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, group_count, group_sizes_vec.data(), + scratchpad_dev, scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (ipiv_dev_ptrs) { + sycl::free(ipiv_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/getrf_batch_stride.cpp b/tests/unit_tests/lapack/source/getrf_batch_stride.cpp index 7390a7b4a..3e4ef6589 100644 --- a/tests/unit_tests/lapack/source/getrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getrf_batch_stride.cpp @@ -65,9 +65,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ queue, m, n, lda, stride_a, stride_ipiv, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, - m, n, lda, stride_a, stride_ipiv, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m, n, + lda, stride_a, stride_ipiv, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -78,8 +78,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, int64_ oneapi::mkl::lapack::getrf_batch(queue, m, n, A_dev, lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getrf_batch, m, n, A_dev, lda, stride_a, - ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getrf_batch, m, n, A_dev, lda, + stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -137,9 +138,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, queue, m, n, lda, stride_a, stride_ipiv, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, - m, n, lda, stride_a, stride_ipiv, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrf_batch_scratchpad_size, m, n, + lda, stride_a, stride_ipiv, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -154,9 +155,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t lda, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrf_batch, m, n, A_dev, lda, - stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrf_batch, m, n, A_dev, + lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/getri.cpp b/tests/unit_tests/lapack/source/getri.cpp index 7aadca235..a1aa2deda 100644 --- a/tests/unit_tests/lapack/source/getri.cpp +++ b/tests/unit_tests/lapack/source/getri.cpp @@ -73,8 +73,8 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, uint64_t seed) { const auto scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size(queue, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, - n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -85,8 +85,8 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, uint64_t seed) { #ifdef CALL_RT_API oneapi::mkl::lapack::getri(queue, n, A_dev, lda, ipiv_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getri, n, A_dev, lda, ipiv_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getri, n, A_dev, lda, ipiv_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -135,8 +135,8 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, uint64_t se const auto scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size(queue, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, - n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getri_scratchpad_size, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -152,8 +152,9 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, uint64_t se scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getri, n, A_dev, lda, ipiv_dev, - scratchpad_dev, scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getri, n, A_dev, lda, + ipiv_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/getri_batch_group.cpp b/tests/unit_tests/lapack/source/getri_batch_group.cpp index dbda9ad40..244acfcc8 100644 --- a/tests/unit_tests/lapack/source/getri_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getri_batch_group.cpp @@ -89,8 +89,8 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> ipiv_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector ipiv_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + int64_t** ipiv_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -110,9 +110,9 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue, n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, - n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, + n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -134,13 +134,13 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue.wait_and_throw(); #ifdef CALL_RT_API - oneapi::mkl::lapack::getri_batch(queue, n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - ipiv_dev_ptrs.data(), group_count, group_sizes_vec.data(), + oneapi::mkl::lapack::getri_batch(queue, n_vec.data(), A_dev_ptrs, lda_vec.data(), + ipiv_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getri_batch, n_vec.data(), A_dev_ptrs.data(), - lda_vec.data(), ipiv_dev_ptrs.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getri_batch, n_vec.data(), A_dev_ptrs, + lda_vec.data(), ipiv_dev_ptrs, group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -149,6 +149,15 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, A_dev_ptrs[global_id], A_iter->data(), A_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (ipiv_dev_ptrs) { + sycl::free(ipiv_dev_ptrs, queue); + } } bool result = true; @@ -228,8 +237,8 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> ipiv_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector ipiv_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + int64_t** ipiv_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -249,9 +258,9 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { queue, n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, - n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, + n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -276,19 +285,28 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::getri_batch( - queue, n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), ipiv_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, + queue, n_vec.data(), A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getri_batch, n_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), ipiv_dev_ptrs.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getri_batch, + n_vec.data(), A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, + group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (ipiv_dev_ptrs) { + sycl::free(ipiv_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/getri_batch_stride.cpp b/tests/unit_tests/lapack/source/getri_batch_stride.cpp index 9ccf5e629..5a71d2d7e 100644 --- a/tests/unit_tests/lapack/source/getri_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getri_batch_stride.cpp @@ -72,9 +72,9 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, int64_t stride_a, queue, n, lda, stride_a, stride_ipiv, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, - n, lda, stride_a, stride_ipiv, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n, lda, + stride_a, stride_ipiv, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -86,8 +86,9 @@ bool accuracy(const sycl::device& dev, int64_t n, int64_t lda, int64_t stride_a, oneapi::mkl::lapack::getri_batch(queue, n, A_dev, lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getri_batch, n, A_dev, lda, stride_a, - ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getri_batch, n, A_dev, lda, stride_a, + ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -151,9 +152,9 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, int64_t str queue, n, lda, stride_a, stride_ipiv, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, - n, lda, stride_a, stride_ipiv, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getri_batch_scratchpad_size, n, lda, + stride_a, stride_ipiv, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -169,9 +170,9 @@ bool usm_dependency(const sycl::device& dev, int64_t n, int64_t lda, int64_t str scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getri_batch, n, A_dev, lda, - stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getri_batch, n, A_dev, + lda, stride_a, ipiv_dev, stride_ipiv, batch_size, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/getrs.cpp b/tests/unit_tests/lapack/source/getrs.cpp index 0b2056fa8..bfc271758 100644 --- a/tests/unit_tests/lapack/source/getrs.cpp +++ b/tests/unit_tests/lapack/source/getrs.cpp @@ -73,8 +73,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, oneapi::mkl::lapack::getrs_scratchpad_size(queue, trans, n, nrhs, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, - trans, n, nrhs, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, + trans, n, nrhs, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -87,8 +88,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, oneapi::mkl::lapack::getrs(queue, trans, n, nrhs, A_dev, lda, ipiv_dev, B_dev, ldb, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getrs, trans, n, nrhs, A_dev, lda, ipiv_dev, - B_dev, ldb, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getrs, trans, n, nrhs, A_dev, lda, + ipiv_dev, B_dev, ldb, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -144,8 +145,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 oneapi::mkl::lapack::getrs_scratchpad_size(queue, trans, n, nrhs, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, - trans, n, nrhs, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::getrs_scratchpad_size, + trans, n, nrhs, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -162,9 +164,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrs, trans, n, nrhs, A_dev, - lda, ipiv_dev, B_dev, ldb, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrs, trans, n, nrhs, + A_dev, lda, ipiv_dev, B_dev, ldb, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/getrs_batch_group.cpp b/tests/unit_tests/lapack/source/getrs_batch_group.cpp index 2ca1c28fb..2027663e4 100644 --- a/tests/unit_tests/lapack/source/getrs_batch_group.cpp +++ b/tests/unit_tests/lapack/source/getrs_batch_group.cpp @@ -106,9 +106,9 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> B_dev_list; std::list>> ipiv_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector B_dev_ptrs(batch_size, nullptr); - std::vector ipiv_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** B_dev_ptrs = sycl::malloc_shared(batch_size, queue); + int64_t** ipiv_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -132,10 +132,10 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, - trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), - ldb_vec.data(), group_count, group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, + trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), + group_count, group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -163,14 +163,14 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { #ifdef CALL_RT_API oneapi::mkl::lapack::getrs_batch(queue, trans_vec.data(), n_vec.data(), nrhs_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), ipiv_dev_ptrs.data(), - B_dev_ptrs.data(), ldb_vec.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size); + A_dev_ptrs, lda_vec.data(), ipiv_dev_ptrs, B_dev_ptrs, + ldb_vec.data(), group_count, group_sizes_vec.data(), + scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getrs_batch, trans_vec.data(), n_vec.data(), - nrhs_vec.data(), A_dev_ptrs.data(), lda_vec.data(), ipiv_dev_ptrs.data(), - B_dev_ptrs.data(), ldb_vec.data(), group_count, group_sizes_vec.data(), - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getrs_batch, trans_vec.data(), + n_vec.data(), nrhs_vec.data(), A_dev_ptrs, lda_vec.data(), + ipiv_dev_ptrs, B_dev_ptrs, ldb_vec.data(), group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -179,6 +179,18 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, B_dev_ptrs[global_id], B_iter->data(), B_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (B_dev_ptrs) { + sycl::free(B_dev_ptrs, queue); + } + if (ipiv_dev_ptrs) { + sycl::free(ipiv_dev_ptrs, queue); + } } bool result = true; @@ -280,9 +292,9 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> B_dev_list; std::list>> ipiv_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector B_dev_ptrs(batch_size, nullptr); - std::vector ipiv_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** B_dev_ptrs = sycl::malloc_shared(batch_size, queue); + int64_t** ipiv_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -306,10 +318,10 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, - trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), - ldb_vec.data(), group_count, group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, + trans_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), + group_count, group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -339,21 +351,32 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::getrs_batch( - queue, trans_vec.data(), n_vec.data(), nrhs_vec.data(), A_dev_ptrs.data(), - lda_vec.data(), ipiv_dev_ptrs.data(), B_dev_ptrs.data(), ldb_vec.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + queue, trans_vec.data(), n_vec.data(), nrhs_vec.data(), A_dev_ptrs, lda_vec.data(), + ipiv_dev_ptrs, B_dev_ptrs, ldb_vec.data(), group_count, group_sizes_vec.data(), + scratchpad_dev, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrs_batch, trans_vec.data(), - n_vec.data(), nrhs_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - ipiv_dev_ptrs.data(), B_dev_ptrs.data(), ldb_vec.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrs_batch, + trans_vec.data(), n_vec.data(), nrhs_vec.data(), A_dev_ptrs, + lda_vec.data(), ipiv_dev_ptrs, B_dev_ptrs, ldb_vec.data(), + group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (B_dev_ptrs) { + sycl::free(B_dev_ptrs, queue); + } + if (ipiv_dev_ptrs) { + sycl::free(ipiv_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/getrs_batch_stride.cpp b/tests/unit_tests/lapack/source/getrs_batch_stride.cpp index 92583e73a..1faf3d3e6 100644 --- a/tests/unit_tests/lapack/source/getrs_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/getrs_batch_stride.cpp @@ -78,9 +78,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, queue, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, - trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, trans, n, + nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -94,9 +94,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::transpose trans, int64_t n, stride_ipiv, B_dev, ldb, stride_b, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::getrs_batch, trans, n, nrhs, A_dev, lda, - stride_a, ipiv_dev, stride_ipiv, B_dev, ldb, stride_b, batch_size, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::getrs_batch, trans, n, nrhs, A_dev, + lda, stride_a, ipiv_dev, stride_ipiv, B_dev, ldb, stride_b, + batch_size, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -167,9 +167,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 queue, trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, - trans, n, nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::getrs_batch_scratchpad_size, trans, n, + nrhs, lda, stride_a, stride_ipiv, ldb, stride_b, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -187,10 +187,10 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::transpose trans, int64 std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrs_batch, trans, n, nrhs, - A_dev, lda, stride_a, ipiv_dev, stride_ipiv, B_dev, ldb, stride_b, - batch_size, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::getrs_batch, trans, n, + nrhs, A_dev, lda, stride_a, ipiv_dev, stride_ipiv, B_dev, ldb, + stride_b, batch_size, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/heevd.cpp b/tests/unit_tests/lapack/source/heevd.cpp index 3d53e6449..62c23c3ad 100644 --- a/tests/unit_tests/lapack/source/heevd.cpp +++ b/tests/unit_tests/lapack/source/heevd.cpp @@ -62,8 +62,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo oneapi::mkl::lapack::heevd_scratchpad_size(queue, jobz, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, - jobz, uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, + jobz, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -74,8 +75,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo oneapi::mkl::lapack::heevd(queue, jobz, uplo, n, A_dev, lda, w_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::heevd, jobz, uplo, n, A_dev, lda, w_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::heevd, jobz, uplo, n, A_dev, lda, + w_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -119,8 +120,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: oneapi::mkl::lapack::heevd_scratchpad_size(queue, jobz, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, - jobz, uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::heevd_scratchpad_size, + jobz, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -135,9 +137,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::heevd, jobz, uplo, n, A_dev, - lda, w_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::heevd, jobz, uplo, n, + A_dev, lda, w_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/hegvd.cpp b/tests/unit_tests/lapack/source/hegvd.cpp index 08932a943..9a109e6b8 100644 --- a/tests/unit_tests/lapack/source/hegvd.cpp +++ b/tests/unit_tests/lapack/source/hegvd.cpp @@ -68,8 +68,9 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one oneapi::mkl::lapack::hegvd_scratchpad_size(queue, itype, jobz, uplo, n, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, - itype, jobz, uplo, n, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, + itype, jobz, uplo, n, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -81,8 +82,8 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one oneapi::mkl::lapack::hegvd(queue, itype, jobz, uplo, n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::hegvd, itype, jobz, uplo, n, A_dev, lda, - B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::hegvd, itype, jobz, uplo, n, A_dev, + lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -255,8 +256,9 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job oneapi::mkl::lapack::hegvd_scratchpad_size(queue, itype, jobz, uplo, n, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, - itype, jobz, uplo, n, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::hegvd_scratchpad_size, + itype, jobz, uplo, n, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -272,9 +274,9 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::hegvd, itype, jobz, uplo, n, - A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::hegvd, itype, jobz, uplo, + n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/hetrd.cpp b/tests/unit_tests/lapack/source/hetrd.cpp index 00f7b9f7f..13172d64f 100644 --- a/tests/unit_tests/lapack/source/hetrd.cpp +++ b/tests/unit_tests/lapack/source/hetrd.cpp @@ -36,7 +36,7 @@ namespace { const char* accuracy_input = R"( -0 33 35 27182 +1 33 35 27182 )"; template @@ -66,8 +66,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::hetrd_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -81,8 +81,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::hetrd(queue, uplo, n, A_dev, lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::hetrd, uplo, n, A_dev, lda, d_dev, e_dev, - tau_dev, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::hetrd, uplo, n, A_dev, lda, d_dev, + e_dev, tau_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -132,7 +132,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ e[diag] -= A[diag + (diag + 1) * lda].real(); else for (int64_t diag = 0; diag < n - 1; diag++) - e[diag] -= A[diag + 1 + (diag)*ldt].real(); + e[diag] -= A[diag + 1 + (diag)*lda].real(); auto ulp = reference::lamch('P'); if (reference::lange('I', n, 1, d.data(), n) > 10.0 * ulp) { @@ -179,8 +179,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::hetrd_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::hetrd_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -198,9 +198,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::hetrd, uplo, n, A_dev, lda, - d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::hetrd, uplo, n, A_dev, + lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/hetrf.cpp b/tests/unit_tests/lapack/source/hetrf.cpp index 0fae16ffd..73535a77f 100644 --- a/tests/unit_tests/lapack/source/hetrf.cpp +++ b/tests/unit_tests/lapack/source/hetrf.cpp @@ -66,8 +66,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::hetrf_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -79,8 +79,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::hetrf(queue, uplo, n, A_dev, lda, ipiv_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::hetrf, uplo, n, A_dev, lda, ipiv_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::hetrf, uplo, n, A_dev, lda, ipiv_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -239,8 +239,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::hetrf_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::hetrf_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -256,9 +256,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::hetrf, uplo, n, A_dev, lda, - ipiv_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::hetrf, uplo, n, A_dev, + lda, ipiv_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/orgbr.cpp b/tests/unit_tests/lapack/source/orgbr.cpp index 5e7e7d0f5..274cafce0 100644 --- a/tests/unit_tests/lapack/source/orgbr.cpp +++ b/tests/unit_tests/lapack/source/orgbr.cpp @@ -82,8 +82,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in oneapi::mkl::lapack::orgbr_scratchpad_size(queue, vect, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, - vect, m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, + vect, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -95,8 +96,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in oneapi::mkl::lapack::orgbr(queue, vect, m, n, k, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::orgbr, vect, m, n, k, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::orgbr, vect, m, n, k, A_dev, lda, + tau_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -156,8 +157,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t oneapi::mkl::lapack::orgbr_scratchpad_size(queue, vect, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, - vect, m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::orgbr_scratchpad_size, + vect, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -173,9 +175,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgbr, vect, m, n, k, A_dev, - lda, tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgbr, vect, m, n, k, + A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/orgqr.cpp b/tests/unit_tests/lapack/source/orgqr.cpp index a6fbbcdbe..9d62daf5f 100644 --- a/tests/unit_tests/lapack/source/orgqr.cpp +++ b/tests/unit_tests/lapack/source/orgqr.cpp @@ -71,8 +71,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t oneapi::mkl::lapack::orgqr_scratchpad_size(queue, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, - m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -84,8 +84,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t oneapi::mkl::lapack::orgqr(queue, m, n, k, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::orgqr, m, n, k, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::orgqr, m, n, k, A_dev, lda, tau_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -132,8 +132,8 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in oneapi::mkl::lapack::orgqr_scratchpad_size(queue, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, - m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgqr_scratchpad_size, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -149,9 +149,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgqr, m, n, k, A_dev, lda, - tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgqr, m, n, k, A_dev, + lda, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/orgqr_batch_group.cpp b/tests/unit_tests/lapack/source/orgqr_batch_group.cpp index e0d83a161..3af796e7d 100644 --- a/tests/unit_tests/lapack/source/orgqr_batch_group.cpp +++ b/tests/unit_tests/lapack/source/orgqr_batch_group.cpp @@ -87,8 +87,8 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> tau_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector tau_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** tau_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -106,10 +106,10 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, - m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, - group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, + m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, + group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -131,13 +131,13 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { #ifdef CALL_RT_API oneapi::mkl::lapack::orgqr_batch(queue, m_vec.data(), n_vec.data(), k_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), tau_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size); + A_dev_ptrs, lda_vec.data(), tau_dev_ptrs, group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::orgqr_batch, m_vec.data(), n_vec.data(), - k_vec.data(), A_dev_ptrs.data(), lda_vec.data(), tau_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::orgqr_batch, m_vec.data(), + n_vec.data(), k_vec.data(), A_dev_ptrs, lda_vec.data(), + tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -146,6 +146,15 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, A_dev_ptrs[global_id], A_iter->data(), A_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (tau_dev_ptrs) { + sycl::free(tau_dev_ptrs, queue); + } } bool result = true; @@ -223,8 +232,8 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> tau_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector tau_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** tau_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -242,10 +251,10 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, - m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, - group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, + m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, + group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -269,19 +278,29 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::orgqr_batch( - queue, m_vec.data(), n_vec.data(), k_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - tau_dev_ptrs.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + queue, m_vec.data(), n_vec.data(), k_vec.data(), A_dev_ptrs, lda_vec.data(), + tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgqr_batch, m_vec.data(), - n_vec.data(), k_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - tau_dev_ptrs.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgqr_batch, + m_vec.data(), n_vec.data(), k_vec.data(), A_dev_ptrs, + lda_vec.data(), tau_dev_ptrs, group_count, group_sizes_vec.data(), + scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (tau_dev_ptrs) { + sycl::free(tau_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp b/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp index 62028bc41..1cf3471c5 100644 --- a/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/orgqr_batch_stride.cpp @@ -71,9 +71,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t queue, m, n, k, lda, stride_a, stride_tau, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, - m, n, k, lda, stride_a, stride_tau, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, m, n, k, + lda, stride_a, stride_tau, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -85,8 +85,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t oneapi::mkl::lapack::orgqr_batch(queue, m, n, k, A_dev, lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::orgqr_batch, m, n, k, A_dev, lda, stride_a, - tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::orgqr_batch, m, n, k, A_dev, lda, + stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -148,9 +149,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in queue, m, n, k, lda, stride_a, stride_tau, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, - m, n, k, lda, stride_a, stride_tau, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgqr_batch_scratchpad_size, m, n, k, + lda, stride_a, stride_tau, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -166,9 +167,10 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgqr_batch, m, n, k, A_dev, - lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgqr_batch, m, n, k, + A_dev, lda, stride_a, tau_dev, stride_tau, batch_size, + scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/orgtr.cpp b/tests/unit_tests/lapack/source/orgtr.cpp index 1e91778e1..5a01745d5 100644 --- a/tests/unit_tests/lapack/source/orgtr.cpp +++ b/tests/unit_tests/lapack/source/orgtr.cpp @@ -69,8 +69,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::orgtr_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -82,8 +82,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::orgtr(queue, uplo, n, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::orgtr, uplo, n, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::orgtr, uplo, n, A_dev, lda, tau_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -133,8 +133,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::orgtr_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::orgtr_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -150,9 +150,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgtr, uplo, n, A_dev, lda, - tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::orgtr, uplo, n, A_dev, + lda, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/ormqr.cpp b/tests/unit_tests/lapack/source/ormqr.cpp index e40f990ab..e2ed49b96 100644 --- a/tests/unit_tests/lapack/source/ormqr.cpp +++ b/tests/unit_tests/lapack/source/ormqr.cpp @@ -79,8 +79,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -93,8 +94,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl oneapi::mkl::lapack::ormqr(queue, left_right, trans, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ormqr, left_right, trans, m, n, k, A_dev, - lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ormqr, left_right, trans, m, n, k, + A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -165,8 +166,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ormqr_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -183,9 +185,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ormqr, left_right, trans, m, n, - k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ormqr, left_right, trans, + m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/ormrq.cpp b/tests/unit_tests/lapack/source/ormrq.cpp index be02df2f6..4882e5bc7 100644 --- a/tests/unit_tests/lapack/source/ormrq.cpp +++ b/tests/unit_tests/lapack/source/ormrq.cpp @@ -89,8 +89,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ormrq_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ormrq_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -103,8 +104,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl oneapi::mkl::lapack::ormrq(queue, left_right, trans, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ormrq, left_right, trans, m, n, k, A_dev, - lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ormrq, left_right, trans, m, n, k, + A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -174,8 +175,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ormrq_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ormrq_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -192,9 +194,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ormrq, left_right, trans, m, n, - k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ormrq, left_right, trans, + m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/ormtr.cpp b/tests/unit_tests/lapack/source/ormtr.cpp index 892786444..4e8dd95b9 100644 --- a/tests/unit_tests/lapack/source/ormtr.cpp +++ b/tests/unit_tests/lapack/source/ormtr.cpp @@ -77,8 +77,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ queue, side, uplo, trans, m, n, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, - side, uplo, trans, m, n, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, + side, uplo, trans, m, n, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -91,8 +92,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ oneapi::mkl::lapack::ormtr(queue, side, uplo, trans, m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ormtr, side, uplo, trans, m, n, A_dev, lda, - tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ormtr, side, uplo, trans, m, n, A_dev, + lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -164,8 +165,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, queue, side, uplo, trans, m, n, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, - side, uplo, trans, m, n, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ormtr_scratchpad_size, + side, uplo, trans, m, n, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -182,9 +184,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ormtr, side, uplo, trans, m, n, - A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ormtr, side, uplo, trans, + m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/potrf.cpp b/tests/unit_tests/lapack/source/potrf.cpp index fae819c6e..7d2df8ea9 100644 --- a/tests/unit_tests/lapack/source/potrf.cpp +++ b/tests/unit_tests/lapack/source/potrf.cpp @@ -64,8 +64,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::potrf_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -75,8 +75,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ #ifdef CALL_RT_API oneapi::mkl::lapack::potrf(queue, uplo, n, A_dev, lda, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potrf, uplo, n, A_dev, lda, scratchpad_dev, - scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potrf, uplo, n, A_dev, lda, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -117,8 +117,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::potrf_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrf_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -133,8 +133,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrf, uplo, n, A_dev, lda, - scratchpad_dev, scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrf, uplo, n, A_dev, + lda, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/potrf_batch_group.cpp b/tests/unit_tests/lapack/source/potrf_batch_group.cpp index e9e73cb63..4a5b8dd58 100644 --- a/tests/unit_tests/lapack/source/potrf_batch_group.cpp +++ b/tests/unit_tests/lapack/source/potrf_batch_group.cpp @@ -79,7 +79,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { sycl::queue queue{ dev, async_error_handler }; std::list>> A_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -95,7 +95,7 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT( + TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif @@ -113,13 +113,13 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { queue.wait_and_throw(); #ifdef CALL_RT_API - oneapi::mkl::lapack::potrf_batch(queue, uplo_vec.data(), n_vec.data(), A_dev_ptrs.data(), + oneapi::mkl::lapack::potrf_batch(queue, uplo_vec.data(), n_vec.data(), A_dev_ptrs, lda_vec.data(), group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potrf_batch, uplo_vec.data(), n_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), group_count, group_sizes_vec.data(), - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potrf_batch, uplo_vec.data(), + n_vec.data(), A_dev_ptrs, lda_vec.data(), group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -128,6 +128,12 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, A_dev_ptrs[global_id], A_iter->data(), A_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } } bool result = true; @@ -196,7 +202,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { sycl::queue queue{ dev, async_error_handler }; std::list>> A_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -212,7 +218,7 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT( + TEST_RUN_LAPACK_CT_SELECT( queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo_vec.data(), n_vec.data(), lda_vec.data(), group_count, group_sizes_vec.data()); #endif @@ -233,19 +239,25 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::potrf_batch( - queue, uplo_vec.data(), n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), group_count, + queue, uplo_vec.data(), n_vec.data(), A_dev_ptrs, lda_vec.data(), group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrf_batch, uplo_vec.data(), - n_vec.data(), A_dev_ptrs.data(), lda_vec.data(), group_count, - group_sizes_vec.data(), scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrf_batch, + uplo_vec.data(), n_vec.data(), A_dev_ptrs, lda_vec.data(), + group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/potrf_batch_stride.cpp b/tests/unit_tests/lapack/source/potrf_batch_stride.cpp index 50bd9a15f..fae4f0bcc 100644 --- a/tests/unit_tests/lapack/source/potrf_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/potrf_batch_stride.cpp @@ -62,9 +62,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ queue, uplo, n, lda, stride_a, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, - uplo, n, lda, stride_a, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo, n, + lda, stride_a, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -75,8 +75,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::potrf_batch(queue, uplo, n, A_dev, lda, stride_a, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potrf_batch, uplo, n, A_dev, lda, stride_a, - batch_size, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potrf_batch, uplo, n, A_dev, lda, + stride_a, batch_size, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -128,9 +128,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, queue, uplo, n, lda, stride_a, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, - uplo, n, lda, stride_a, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrf_batch_scratchpad_size, uplo, n, + lda, stride_a, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -145,9 +145,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrf_batch, uplo, n, A_dev, - lda, stride_a, batch_size, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrf_batch, uplo, n, + A_dev, lda, stride_a, batch_size, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/potri.cpp b/tests/unit_tests/lapack/source/potri.cpp index e563e6514..cd2f86449 100644 --- a/tests/unit_tests/lapack/source/potri.cpp +++ b/tests/unit_tests/lapack/source/potri.cpp @@ -68,8 +68,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::potri_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -79,8 +79,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ #ifdef CALL_RT_API oneapi::mkl::lapack::potri(queue, uplo, n, A_dev, lda, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potri, uplo, n, A_dev, lda, scratchpad_dev, - scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potri, uplo, n, A_dev, lda, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -151,8 +151,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::potri_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potri_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -167,8 +167,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potri, uplo, n, A_dev, lda, - scratchpad_dev, scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potri, uplo, n, A_dev, + lda, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/potrs.cpp b/tests/unit_tests/lapack/source/potrs.cpp index 6d5fddbb0..c534ec8ba 100644 --- a/tests/unit_tests/lapack/source/potrs.cpp +++ b/tests/unit_tests/lapack/source/potrs.cpp @@ -71,8 +71,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::potrs_scratchpad_size(queue, uplo, n, nrhs, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, - uplo, n, nrhs, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, + uplo, n, nrhs, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -84,8 +85,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::potrs(queue, uplo, n, nrhs, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potrs, uplo, n, nrhs, A_dev, lda, B_dev, ldb, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potrs, uplo, n, nrhs, A_dev, lda, + B_dev, ldb, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -137,8 +138,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::potrs_scratchpad_size(queue, uplo, n, nrhs, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, - uplo, n, nrhs, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::potrs_scratchpad_size, + uplo, n, nrhs, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -154,9 +156,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrs, uplo, n, nrhs, A_dev, - lda, B_dev, ldb, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrs, uplo, n, nrhs, + A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/potrs_batch_group.cpp b/tests/unit_tests/lapack/source/potrs_batch_group.cpp index 886fe8d77..35c5ead0c 100644 --- a/tests/unit_tests/lapack/source/potrs_batch_group.cpp +++ b/tests/unit_tests/lapack/source/potrs_batch_group.cpp @@ -100,8 +100,8 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> B_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector B_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** B_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -119,10 +119,10 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, - uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), - ldb_vec.data(), group_count, group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, + uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), + group_count, group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -144,14 +144,14 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { #ifdef CALL_RT_API oneapi::mkl::lapack::potrs_batch(queue, uplo_vec.data(), n_vec.data(), nrhs_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), B_dev_ptrs.data(), - ldb_vec.data(), group_count, group_sizes_vec.data(), - scratchpad_dev, scratchpad_size); + A_dev_ptrs, lda_vec.data(), B_dev_ptrs, ldb_vec.data(), + group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potrs_batch, uplo_vec.data(), n_vec.data(), - nrhs_vec.data(), A_dev_ptrs.data(), lda_vec.data(), B_dev_ptrs.data(), - ldb_vec.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potrs_batch, uplo_vec.data(), + n_vec.data(), nrhs_vec.data(), A_dev_ptrs, lda_vec.data(), + B_dev_ptrs, ldb_vec.data(), group_count, group_sizes_vec.data(), + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -160,6 +160,15 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, B_dev_ptrs[global_id], B_iter->data(), B_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (B_dev_ptrs) { + sycl::free(B_dev_ptrs, queue); + } } bool result = true; @@ -254,8 +263,8 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> B_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector B_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** B_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -273,10 +282,10 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_count, group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, - uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), - ldb_vec.data(), group_count, group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, + uplo_vec.data(), n_vec.data(), nrhs_vec.data(), lda_vec.data(), ldb_vec.data(), + group_count, group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -300,19 +309,29 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::potrs_batch( - queue, uplo_vec.data(), n_vec.data(), nrhs_vec.data(), A_dev_ptrs.data(), - lda_vec.data(), B_dev_ptrs.data(), ldb_vec.data(), group_count, group_sizes_vec.data(), - scratchpad_dev, scratchpad_size, std::vector{ in_event }); + queue, uplo_vec.data(), n_vec.data(), nrhs_vec.data(), A_dev_ptrs, lda_vec.data(), + B_dev_ptrs, ldb_vec.data(), group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrs_batch, uplo_vec.data(), - n_vec.data(), nrhs_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - B_dev_ptrs.data(), ldb_vec.data(), group_count, group_sizes_vec.data(), - scratchpad_dev, scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrs_batch, + uplo_vec.data(), n_vec.data(), nrhs_vec.data(), A_dev_ptrs, + lda_vec.data(), B_dev_ptrs, ldb_vec.data(), group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (B_dev_ptrs) { + sycl::free(B_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/potrs_batch_stride.cpp b/tests/unit_tests/lapack/source/potrs_batch_stride.cpp index dac8488e8..de2568e86 100644 --- a/tests/unit_tests/lapack/source/potrs_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/potrs_batch_stride.cpp @@ -76,9 +76,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ queue, uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, - uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, uplo, n, + nrhs, lda, stride_a, ldb, stride_b, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -90,9 +90,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::potrs_batch(queue, uplo, n, nrhs, A_dev, lda, stride_a, B_dev, ldb, stride_b, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::potrs_batch, uplo, n, nrhs, A_dev, lda, - stride_a, B_dev, ldb, stride_b, batch_size, scratchpad_dev, - scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::potrs_batch, uplo, n, nrhs, A_dev, + lda, stride_a, B_dev, ldb, stride_b, batch_size, scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -161,9 +161,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, queue, uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, - uplo, n, nrhs, lda, stride_a, ldb, stride_b, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::potrs_batch_scratchpad_size, uplo, n, + nrhs, lda, stride_a, ldb, stride_b, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -179,9 +179,10 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_dev, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrs_batch, uplo, n, nrhs, - A_dev, lda, stride_a, B_dev, ldb, stride_b, batch_size, scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::potrs_batch, uplo, n, + nrhs, A_dev, lda, stride_a, B_dev, ldb, stride_b, batch_size, + scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/syevd.cpp b/tests/unit_tests/lapack/source/syevd.cpp index ab01ef23d..291713354 100644 --- a/tests/unit_tests/lapack/source/syevd.cpp +++ b/tests/unit_tests/lapack/source/syevd.cpp @@ -62,8 +62,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo oneapi::mkl::lapack::syevd_scratchpad_size(queue, jobz, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, - jobz, uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, + jobz, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -74,8 +75,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl::uplo oneapi::mkl::lapack::syevd(queue, jobz, uplo, n, A_dev, lda, w_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::syevd, jobz, uplo, n, A_dev, lda, w_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::syevd, jobz, uplo, n, A_dev, lda, + w_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -119,8 +120,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: oneapi::mkl::lapack::syevd_scratchpad_size(queue, jobz, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, - jobz, uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::syevd_scratchpad_size, + jobz, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -135,9 +137,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::job jobz, oneapi::mkl: scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::syevd, jobz, uplo, n, A_dev, - lda, w_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::syevd, jobz, uplo, n, + A_dev, lda, w_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/sygvd.cpp b/tests/unit_tests/lapack/source/sygvd.cpp index 8ad5351e1..f800b03dd 100644 --- a/tests/unit_tests/lapack/source/sygvd.cpp +++ b/tests/unit_tests/lapack/source/sygvd.cpp @@ -68,8 +68,9 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one oneapi::mkl::lapack::sygvd_scratchpad_size(queue, itype, jobz, uplo, n, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, - itype, jobz, uplo, n, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, + itype, jobz, uplo, n, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -81,8 +82,8 @@ bool accuracy(const sycl::device& dev, int64_t itype, oneapi::mkl::job jobz, one oneapi::mkl::lapack::sygvd(queue, itype, jobz, uplo, n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::sygvd, itype, jobz, uplo, n, A_dev, lda, - B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::sygvd, itype, jobz, uplo, n, A_dev, + lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -260,8 +261,9 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job oneapi::mkl::lapack::sygvd_scratchpad_size(queue, itype, jobz, uplo, n, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, - itype, jobz, uplo, n, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::sygvd_scratchpad_size, + itype, jobz, uplo, n, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -277,9 +279,9 @@ bool usm_dependency(const sycl::device& dev, int64_t itype, oneapi::mkl::job job scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::sygvd, itype, jobz, uplo, n, - A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::sygvd, itype, jobz, uplo, + n, A_dev, lda, B_dev, ldb, w_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/sytrd.cpp b/tests/unit_tests/lapack/source/sytrd.cpp index 2009cb1d1..01ffe0dff 100644 --- a/tests/unit_tests/lapack/source/sytrd.cpp +++ b/tests/unit_tests/lapack/source/sytrd.cpp @@ -36,7 +36,7 @@ namespace { const char* accuracy_input = R"( -0 33 35 27182 +1 33 35 27182 )"; template @@ -66,8 +66,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::sytrd_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -81,8 +81,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::sytrd(queue, uplo, n, A_dev, lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::sytrd, uplo, n, A_dev, lda, d_dev, e_dev, - tau_dev, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::sytrd, uplo, n, A_dev, lda, d_dev, + e_dev, tau_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -132,7 +132,7 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ e[diag] -= A[diag + (diag + 1) * lda]; else for (int64_t diag = 0; diag < n - 1; diag++) - e[diag] -= A[diag + 1 + (diag)*ldt]; + e[diag] -= A[diag + 1 + (diag)*lda]; auto ulp = reference::lamch('P'); if (reference::lange('I', n, 1, d.data(), n) > 10.0 * ulp) { @@ -179,8 +179,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::sytrd_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::sytrd_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -198,9 +198,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::sytrd, uplo, n, A_dev, lda, - d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::sytrd, uplo, n, A_dev, + lda, d_dev, e_dev, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/sytrf.cpp b/tests/unit_tests/lapack/source/sytrf.cpp index b1f850776..81d7fdb2d 100644 --- a/tests/unit_tests/lapack/source/sytrf.cpp +++ b/tests/unit_tests/lapack/source/sytrf.cpp @@ -63,8 +63,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::sytrf_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -76,8 +76,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::sytrf(queue, uplo, n, A_dev, lda, ipiv_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::sytrf, uplo, n, A_dev, lda, ipiv_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::sytrf, uplo, n, A_dev, lda, ipiv_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -236,8 +236,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::sytrf_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::sytrf_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -253,9 +253,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::sytrf, uplo, n, A_dev, lda, - ipiv_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::sytrf, uplo, n, A_dev, + lda, ipiv_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/trtrs.cpp b/tests/unit_tests/lapack/source/trtrs.cpp index 7766b90c5..4018a2c51 100644 --- a/tests/unit_tests/lapack/source/trtrs.cpp +++ b/tests/unit_tests/lapack/source/trtrs.cpp @@ -74,8 +74,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl::tran queue, uplo, trans, diag, n, nrhs, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, - uplo, trans, diag, n, nrhs, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, + uplo, trans, diag, n, nrhs, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -87,8 +88,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl::tran oneapi::mkl::lapack::trtrs(queue, uplo, trans, diag, n, nrhs, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::trtrs, uplo, trans, diag, n, nrhs, A_dev, - lda, B_dev, ldb, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::trtrs, uplo, trans, diag, n, nrhs, + A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -136,8 +137,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl queue, uplo, trans, diag, n, nrhs, lda, ldb); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, - uplo, trans, diag, n, nrhs, lda, ldb); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::trtrs_scratchpad_size, + uplo, trans, diag, n, nrhs, lda, ldb); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -153,9 +155,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, oneapi::mkl scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::trtrs, uplo, trans, diag, n, - nrhs, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::trtrs, uplo, trans, diag, + n, nrhs, A_dev, lda, B_dev, ldb, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/ungbr.cpp b/tests/unit_tests/lapack/source/ungbr.cpp index 3b0b96c45..7cdf8e52a 100644 --- a/tests/unit_tests/lapack/source/ungbr.cpp +++ b/tests/unit_tests/lapack/source/ungbr.cpp @@ -82,8 +82,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in oneapi::mkl::lapack::ungbr_scratchpad_size(queue, vect, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, - vect, m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, + vect, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -95,8 +96,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::generate vect, int64_t m, in oneapi::mkl::lapack::ungbr(queue, vect, m, n, k, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ungbr, vect, m, n, k, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ungbr, vect, m, n, k, A_dev, lda, + tau_dev, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -156,8 +157,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t oneapi::mkl::lapack::ungbr_scratchpad_size(queue, vect, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, - vect, m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::ungbr_scratchpad_size, + vect, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -173,9 +175,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::generate vect, int64_t scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungbr, vect, m, n, k, A_dev, - lda, tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungbr, vect, m, n, k, + A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/ungqr.cpp b/tests/unit_tests/lapack/source/ungqr.cpp index 5576098fc..08b8b1192 100644 --- a/tests/unit_tests/lapack/source/ungqr.cpp +++ b/tests/unit_tests/lapack/source/ungqr.cpp @@ -70,8 +70,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t oneapi::mkl::lapack::ungqr_scratchpad_size(queue, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, - m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -83,8 +83,8 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t oneapi::mkl::lapack::ungqr(queue, m, n, k, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ungqr, m, n, k, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ungqr, m, n, k, A_dev, lda, tau_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -131,8 +131,8 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in oneapi::mkl::lapack::ungqr_scratchpad_size(queue, m, n, k, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, - m, n, k, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungqr_scratchpad_size, m, n, k, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -148,9 +148,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungqr, m, n, k, A_dev, lda, - tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungqr, m, n, k, A_dev, + lda, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/ungqr_batch_group.cpp b/tests/unit_tests/lapack/source/ungqr_batch_group.cpp index 71f6bc927..ddb350828 100644 --- a/tests/unit_tests/lapack/source/ungqr_batch_group.cpp +++ b/tests/unit_tests/lapack/source/ungqr_batch_group.cpp @@ -87,8 +87,8 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> tau_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector tau_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** tau_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -106,10 +106,10 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, - m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, - group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, + m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, + group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -131,13 +131,13 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { #ifdef CALL_RT_API oneapi::mkl::lapack::ungqr_batch(queue, m_vec.data(), n_vec.data(), k_vec.data(), - A_dev_ptrs.data(), lda_vec.data(), tau_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size); + A_dev_ptrs, lda_vec.data(), tau_dev_ptrs, group_count, + group_sizes_vec.data(), scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ungqr_batch, m_vec.data(), n_vec.data(), - k_vec.data(), A_dev_ptrs.data(), lda_vec.data(), tau_dev_ptrs.data(), - group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ungqr_batch, m_vec.data(), + n_vec.data(), k_vec.data(), A_dev_ptrs, lda_vec.data(), + tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -146,6 +146,15 @@ bool accuracy(const sycl::device& dev, uint64_t seed) { device_to_host_copy(queue, A_dev_ptrs[global_id], A_iter->data(), A_iter->size()); } queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (tau_dev_ptrs) { + sycl::free(tau_dev_ptrs, queue); + } } bool result = true; @@ -223,8 +232,8 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { std::list>> A_dev_list; std::list>> tau_dev_list; - std::vector A_dev_ptrs(batch_size, nullptr); - std::vector tau_dev_ptrs(batch_size, nullptr); + fp** A_dev_ptrs = sycl::malloc_shared(batch_size, queue); + fp** tau_dev_ptrs = sycl::malloc_shared(batch_size, queue); /* Allocate on device */ sycl::usm_allocator usm_fp_allocator{ queue.get_context(), @@ -242,10 +251,10 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { group_sizes_vec.data()); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, - m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, - group_sizes_vec.data()); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, + m_vec.data(), n_vec.data(), k_vec.data(), lda_vec.data(), group_count, + group_sizes_vec.data()); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -269,19 +278,29 @@ bool usm_dependency(const sycl::device& dev, uint64_t seed) { auto in_event = create_dependency(queue); #ifdef CALL_RT_API sycl::event func_event = oneapi::mkl::lapack::ungqr_batch( - queue, m_vec.data(), n_vec.data(), k_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - tau_dev_ptrs.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + queue, m_vec.data(), n_vec.data(), k_vec.data(), A_dev_ptrs, lda_vec.data(), + tau_dev_ptrs, group_count, group_sizes_vec.data(), scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungqr_batch, m_vec.data(), - n_vec.data(), k_vec.data(), A_dev_ptrs.data(), lda_vec.data(), - tau_dev_ptrs.data(), group_count, group_sizes_vec.data(), scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungqr_batch, + m_vec.data(), n_vec.data(), k_vec.data(), A_dev_ptrs, + lda_vec.data(), tau_dev_ptrs, group_count, group_sizes_vec.data(), + scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); + if (scratchpad_dev) { + sycl::free(scratchpad_dev, queue); + } + if (A_dev_ptrs) { + sycl::free(A_dev_ptrs, queue); + } + if (tau_dev_ptrs) { + sycl::free(tau_dev_ptrs, queue); + } } return result; diff --git a/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp b/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp index 55c849de2..e656b9fb7 100644 --- a/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp +++ b/tests/unit_tests/lapack/source/ungqr_batch_stride.cpp @@ -71,9 +71,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t queue, m, n, k, lda, stride_a, stride_tau, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, - m, n, k, lda, stride_a, stride_tau, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, m, n, k, + lda, stride_a, stride_tau, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -85,8 +85,9 @@ bool accuracy(const sycl::device& dev, int64_t m, int64_t n, int64_t k, int64_t oneapi::mkl::lapack::ungqr_batch(queue, m, n, k, A_dev, lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ungqr_batch, m, n, k, A_dev, lda, stride_a, - tau_dev, stride_tau, batch_size, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ungqr_batch, m, n, k, A_dev, lda, + stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, + scratchpad_size); #endif queue.wait_and_throw(); @@ -148,9 +149,9 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in queue, m, n, k, lda, stride_a, stride_tau, batch_size); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, - scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, - m, n, k, lda, stride_a, stride_tau, batch_size); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungqr_batch_scratchpad_size, m, n, k, + lda, stride_a, stride_tau, batch_size); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -166,9 +167,10 @@ bool usm_dependency(const sycl::device& dev, int64_t m, int64_t n, int64_t k, in scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungqr_batch, m, n, k, A_dev, - lda, stride_a, tau_dev, stride_tau, batch_size, scratchpad_dev, - scratchpad_size, std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungqr_batch, m, n, k, + A_dev, lda, stride_a, tau_dev, stride_tau, batch_size, + scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/ungtr.cpp b/tests/unit_tests/lapack/source/ungtr.cpp index 0205448b2..b0ad8e8f2 100644 --- a/tests/unit_tests/lapack/source/ungtr.cpp +++ b/tests/unit_tests/lapack/source/ungtr.cpp @@ -69,8 +69,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::ungtr_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -82,8 +82,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, int64_ oneapi::mkl::lapack::ungtr(queue, uplo, n, A_dev, lda, tau_dev, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::ungtr, uplo, n, A_dev, lda, tau_dev, - scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::ungtr, uplo, n, A_dev, lda, tau_dev, + scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -133,8 +133,8 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, oneapi::mkl::lapack::ungtr_scratchpad_size(queue, uplo, n, lda); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, - uplo, n, lda); + TEST_RUN_LAPACK_CT_SELECT( + queue, scratchpad_size = oneapi::mkl::lapack::ungtr_scratchpad_size, uplo, n, lda); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -150,9 +150,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t n, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungtr, uplo, n, A_dev, lda, - tau_dev, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::ungtr, uplo, n, A_dev, + lda, tau_dev, scratchpad_dev, scratchpad_size, + std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/lapack/source/unmqr.cpp b/tests/unit_tests/lapack/source/unmqr.cpp index 1afd562cb..2f555c1ca 100644 --- a/tests/unit_tests/lapack/source/unmqr.cpp +++ b/tests/unit_tests/lapack/source/unmqr.cpp @@ -79,8 +79,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -93,8 +94,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl oneapi::mkl::lapack::unmqr(queue, left_right, trans, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::unmqr, left_right, trans, m, n, k, A_dev, - lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::unmqr, left_right, trans, m, n, k, + A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -165,8 +166,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::unmqr_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -183,9 +185,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::unmqr, left_right, trans, m, n, - k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::unmqr, left_right, trans, + m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/unmrq.cpp b/tests/unit_tests/lapack/source/unmrq.cpp index 8cdcc0fe3..628063837 100644 --- a/tests/unit_tests/lapack/source/unmrq.cpp +++ b/tests/unit_tests/lapack/source/unmrq.cpp @@ -89,8 +89,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -103,8 +104,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::side left_right, oneapi::mkl oneapi::mkl::lapack::unmrq(queue, left_right, trans, m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::unmrq, left_right, trans, m, n, k, A_dev, - lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::unmrq, left_right, trans, m, n, k, + A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -174,8 +175,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, queue, left_right, trans, m, n, k, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, - left_right, trans, m, n, k, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::unmrq_scratchpad_size, + left_right, trans, m, n, k, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -192,9 +194,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::side left_right, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::unmrq, left_right, trans, m, n, - k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::unmrq, left_right, trans, + m, n, k, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); diff --git a/tests/unit_tests/lapack/source/unmtr.cpp b/tests/unit_tests/lapack/source/unmtr.cpp index 095f1286a..8148c644d 100644 --- a/tests/unit_tests/lapack/source/unmtr.cpp +++ b/tests/unit_tests/lapack/source/unmtr.cpp @@ -77,8 +77,9 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ queue, side, uplo, trans, m, n, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, - side, uplo, trans, m, n, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, + side, uplo, trans, m, n, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -91,8 +92,8 @@ bool accuracy(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, int64_ oneapi::mkl::lapack::unmtr(queue, side, uplo, trans, m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #else - TEST_RUN_CT_SELECT(queue, oneapi::mkl::lapack::unmtr, side, uplo, trans, m, n, A_dev, lda, - tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); + TEST_RUN_LAPACK_CT_SELECT(queue, oneapi::mkl::lapack::unmtr, side, uplo, trans, m, n, A_dev, + lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size); #endif queue.wait_and_throw(); @@ -164,8 +165,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, queue, side, uplo, trans, m, n, lda, ldc); #else int64_t scratchpad_size; - TEST_RUN_CT_SELECT(queue, scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, - side, uplo, trans, m, n, lda, ldc); + TEST_RUN_LAPACK_CT_SELECT(queue, + scratchpad_size = oneapi::mkl::lapack::unmtr_scratchpad_size, + side, uplo, trans, m, n, lda, ldc); #endif auto scratchpad_dev = device_alloc(queue, scratchpad_size); @@ -182,9 +184,9 @@ bool usm_dependency(const sycl::device& dev, oneapi::mkl::uplo uplo, int64_t m, scratchpad_size, std::vector{ in_event }); #else sycl::event func_event; - TEST_RUN_CT_SELECT(queue, func_event = oneapi::mkl::lapack::unmtr, side, uplo, trans, m, n, - A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, scratchpad_size, - std::vector{ in_event }); + TEST_RUN_LAPACK_CT_SELECT(queue, func_event = oneapi::mkl::lapack::unmtr, side, uplo, trans, + m, n, A_dev, lda, tau_dev, C_dev, ldc, scratchpad_dev, + scratchpad_size, std::vector{ in_event }); #endif result = check_dependency(queue, in_event, func_event); queue.wait_and_throw(); diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index 5c07d3036..bac3f8c83 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -110,21 +110,25 @@ int main(int argc, char** argv) { unique_devices.insert(dev.get_info()); unsigned int vendor_id = static_cast( dev.get_info()); -#ifndef ENABLE_MKLCPU_BACKEND +#if !defined(ENABLE_MKLCPU_BACKEND) && !defined(ENABLE_PORTBLAS_BACKEND_INTEL_CPU) && \ + !defined(ENABLE_PORTFFT_BACKEND) if (dev.is_cpu()) continue; #endif -#ifndef ENABLE_MKLGPU_BACKEND +#if !defined(ENABLE_MKLGPU_BACKEND) && !defined(ENABLE_PORTBLAS_BACKEND_INTEL_GPU) && \ + !defined(ENABLE_PORTFFT_BACKEND) if (dev.is_gpu() && vendor_id == INTEL_ID) continue; #endif -#if !defined(ENABLE_CUBLAS_BACKEND) && !defined(ENABLE_CURAND_BACKEND) && \ - !defined(ENABLE_CUSOLVER_BACKEND) +#if !defined(ENABLE_CUBLAS_BACKEND) && !defined(ENABLE_CURAND_BACKEND) && \ + !defined(ENABLE_CUSOLVER_BACKEND) && !defined(ENABLE_PORTBLAS_BACKEND_NVIDIA_GPU) && \ + !defined(ENABLE_CUFFT_BACKEND) && !defined(ENABLE_PORTFFT_BACKEND) if (dev.is_gpu() && vendor_id == NVIDIA_ID) continue; #endif -#if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \ - !defined(ENABLE_ROCSOLVER_BACKEND) +#if !defined(ENABLE_ROCBLAS_BACKEND) && !defined(ENABLE_ROCRAND_BACKEND) && \ + !defined(ENABLE_ROCSOLVER_BACKEND) && !defined(ENABLE_PORTBLAS_BACKEND_AMD_GPU) && \ + !defined(ENABLE_ROCFFT_BACKEND) && !defined(ENABLE_PORTFFT_BACKEND) if (dev.is_gpu() && vendor_id == AMD_ID) continue; #endif @@ -147,7 +151,8 @@ int main(int argc, char** argv) { #endif } -#if defined(ENABLE_MKLCPU_BACKEND) || defined(ENABLE_NETLIB_BACKEND) +#if defined(ENABLE_MKLCPU_BACKEND) || defined(ENABLE_NETLIB_BACKEND) || \ + defined(ENABLE_PORTBLAS_BACKEND_INTEL_CPU) #ifdef __HIPSYCL__ local_devices.push_back(sycl::device(sycl::cpu_selector())); #else diff --git a/tests/unit_tests/rng/CMakeLists.txt b/tests/unit_tests/rng/CMakeLists.txt index 46bf53bcf..a2f077d35 100644 --- a/tests/unit_tests/rng/CMakeLists.txt +++ b/tests/unit_tests/rng/CMakeLists.txt @@ -17,5 +17,6 @@ # SPDX-License-Identifier: Apache-2.0 #=============================================================================== +add_subdirectory(device) add_subdirectory(service) add_subdirectory(statistics_check) diff --git a/tests/unit_tests/rng/device/CMakeLists.txt b/tests/unit_tests/rng/device/CMakeLists.txt new file mode 100644 index 000000000..e3f36d972 --- /dev/null +++ b/tests/unit_tests/rng/device/CMakeLists.txt @@ -0,0 +1,21 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +add_subdirectory(moments) +add_subdirectory(service) diff --git a/tests/unit_tests/rng/device/include/moments.hpp b/tests/unit_tests/rng/device/include/moments.hpp new file mode 100644 index 000000000..8acf20bf9 --- /dev/null +++ b/tests/unit_tests/rng/device/include/moments.hpp @@ -0,0 +1,121 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +* +* Content: +* oneapi::mkl::rng::device:: distributions moments test (SYCL interface) +* +*******************************************************************************/ + +#ifndef _RNG_DEVICE_DISTR_MOMENTS_TEST_HPP_ +#define _RNG_DEVICE_DISTR_MOMENTS_TEST_HPP_ + +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/rng/device.hpp" + +#include "rng_device_test_common.hpp" + +template +class moments_test { +public: + template + void operator()(Queue queue) { + // Note: the following methods of discrete distributions require double precision support + if ((std::is_same_v< + Distribution, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>> || + std::is_same_v< + Distribution, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>> || + std::is_same_v> || + std::is_same_v< + Distribution, + oneapi::mkl::rng::device::poisson< + std::int32_t, oneapi::mkl::rng::device::poisson_method::devroye>>)&&!queue + .get_device() + .has(sycl::aspect::fp64)) { + status = test_skipped; + return; + } + using Type = typename Distribution::result_type; + // prepare array for random numbers + std::vector r(N_GEN); + + try { + sycl::range<1> range(N_GEN / Engine::vec_size); + + sycl::buffer buf(r); + auto event = queue.submit([&](sycl::handler& cgh) { + sycl::accessor acc(buf, cgh, sycl::write_only); + cgh.parallel_for(range, [=](sycl::item<1> item) { + size_t id = item.get_id(0); + auto multiplier = Engine::vec_size; + if constexpr (std::is_same_v>) + multiplier *= 2; + Engine engine(SEED, id * multiplier); + Distribution distr; + auto res = oneapi::mkl::rng::device::generate(distr, engine); + if constexpr (Engine::vec_size == 1) { + acc[id] = res; + } + else { + res.store(id, get_multi_ptr(acc)); + } + }); + }); + event.wait_and_throw(); + } + catch (const oneapi::mkl::unimplemented& e) { + status = test_skipped; + return; + } + catch (sycl::exception const& e) { + std::cout << "SYCL exception during generation" << std::endl + << e.what() << std::endl + << "Error code: " << get_error_code(e) << std::endl; + status = test_failed; + return; + } + + // validation (statistics check is turned out for mcg59) + if constexpr (!std::is_same>::value) { + statistics_device stat; + status = stat.check(r, Distribution{}); + } + return; + } + + int status = test_passed; +}; + +#endif // _RNG_DEVICE_DISTR_MOMENTS_TEST_HPP_ diff --git a/tests/unit_tests/rng/device/include/rng_device_test_common.hpp b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp new file mode 100644 index 000000000..6b014f0ec --- /dev/null +++ b/tests/unit_tests/rng/device/include/rng_device_test_common.hpp @@ -0,0 +1,342 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _RNG_DEVICE_TEST_COMMON_HPP__ +#define _RNG_DEVICE_TEST_COMMON_HPP__ + +#include +#include + +#include "test_helper.hpp" + +#define SEED 777 +#define N_GEN 960 + +// Defines for skip_ahead and leapfrog tests +#define N_ENGINES 5 +#define N_PORTION 100 +#define N_GEN_SERVICE (N_ENGINES * N_PORTION) + +// defines for skip_ahead_ex tests +#define N_SKIP ((std::uint64_t)pow(2, 62)) +#define SKIP_TIMES ((std::int32_t)pow(2, 14)) +#define NUM_TO_SKIP \ + { 0, (std::uint64_t)pow(2, 12) } + +// Correctness checking. +static inline bool check_equal_device(float x, float x_ref) { + float bound = std::numeric_limits::epsilon(); + float aerr = std::abs(x - x_ref); + return (aerr <= bound); +} + +static inline bool check_equal_device(double x, double x_ref) { + double bound = std::numeric_limits::epsilon(); + double aerr = std::abs(x - x_ref); + return (aerr <= bound); +} + +static inline bool check_equal_device(std::uint32_t x, std::uint32_t x_ref) { + return x == x_ref; +} + +static inline bool check_equal_device(std::uint64_t x, std::uint64_t x_ref) { + return x == x_ref; +} + +template +static inline bool check_equal_vector_device(std::vector& r1, + std::vector& r2) { + bool good = true; + for (int i = 0; i < r1.size(); i++) { + if (!check_equal_device(r1[i], r2[i])) { + good = false; + break; + } + } + return good; +} + +template +class rng_device_test { +public: + // method to call any tests, switch between rt and ct + template + int operator()(sycl::device* dev, Args... args) { + auto exception_handler = [](sycl::exception_list exceptions) { + for (std::exception_ptr const& e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const& e) { + std::cout << "Caught asynchronous SYCL exception during ASUM:\n" + << e.what() << std::endl; + print_error_code(e); + } + } + }; + + sycl::queue queue(*dev, exception_handler); + + test_(queue, args...); + + return test_.status; + } + +protected: + Test test_; +}; + +template +struct has_member_code_meta : std::false_type {}; + +template +struct has_member_code_meta().get_multi_ptr())>> + : std::true_type {}; + +template ::value>::type* = nullptr> +auto get_multi_ptr(T acc) { +#ifndef __HIPSYCL__ + return acc.get_multi_ptr(); +#else + return acc.get_pointer(); +#endif +}; + +template ::value>::type* = nullptr> +auto get_multi_ptr(T acc) { +#ifndef __HIPSYCL__ + return acc.template get_multi_ptr(); +#else + return acc.get_pointer(); +#endif +}; + +template +auto get_error_code(T x) { + return x.code().value(); +}; + +template +bool compare_moments(const std::vector& r, double tM, double tD, double tQ) { + double tD2; + double sM, sD; + double sum, sum2; + double n, s; + double DeltaM, DeltaD; + + // sample moments + sum = 0.0; + sum2 = 0.0; + for (int i = 0; i < N_GEN; i++) { + sum += (double)r[i]; + sum2 += (double)r[i] * (double)r[i]; + } + sM = sum / ((double)N_GEN); + sD = sum2 / (double)N_GEN - (sM * sM); + + // Comparison of theoretical and sample moments + n = (double)N_GEN; + tD2 = tD * tD; + s = ((tQ - tD2) / n) - (2 * (tQ - 2 * tD2) / (n * n)) + ((tQ - 3 * tD2) / (n * n * n)); + + DeltaM = (tM - sM) / std::sqrt(tD / n); + DeltaD = (tD - sD) / std::sqrt(s); + if (fabs(DeltaM) > 3.0 || fabs(DeltaD) > 10.0) { + std::cout << "Error: sample moments (mean=" << sM << ", variance=" << sD + << ") disagree with theory (mean=" << tM << ", variance=" << tD << ")" + << " N_GEN = " << N_GEN << std::endl; + return false; + } + return true; +} + +template +struct statistics_device {}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + Fp a = distr.a(); + Fp b = distr.b(); + + // Theoretical moments + tM = (b + a) / 2.0; + tD = ((b - a) * (b - a)) / 12.0; + tQ = ((b - a) * (b - a) * (b - a) * (b - a)) / 80.0; + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + double a = distr.a(); + double b = distr.b(); + + // Theoretical moments + tM = (a + b - 1.0) / 2.0; + tD = ((b - a) * (b - a) - 1.0) / 12.0; + tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) + + (7.0 / 240.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform& distr) { + double tM, tD, tQ; + double a = distr.a(); + double b = distr.b(); + + // Theoretical moments + tM = (a + b - 1.0) / 2.0; + tD = ((b - a) * (b - a) - 1.0) / 12.0; + tQ = (((b - a) * (b - a)) * ((1.0 / 80.0) * (b - a) * (b - a) - (1.0 / 24.0))) + + (7.0 / 240.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::gaussian& distr) { + double tM, tD, tQ; + Fp a = distr.mean(); + Fp sigma = distr.stddev(); + + // Theoretical moments + tM = a; + tD = sigma * sigma; + tQ = 720.0 * sigma * sigma * sigma * sigma; + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::lognormal& distr) { + double tM, tD, tQ; + Fp a = distr.m(); + Fp b = distr.displ(); + Fp sigma = distr.s(); + Fp beta = distr.scale(); + + // Theoretical moments + tM = b + beta * std::exp(a + sigma * sigma * 0.5); + tD = beta * beta * std::exp(2.0 * a + sigma * sigma) * (std::exp(sigma * sigma) - 1.0); + tQ = beta * beta * beta * beta * std::exp(4.0 * a + 2.0 * sigma * sigma) * + (std::exp(6.0 * sigma * sigma) - 4.0 * std::exp(3.0 * sigma * sigma) + + 6.0 * std::exp(sigma * sigma) - 3.0); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::exponential& distr) { + double tM, tD, tQ; + Fp a = distr.a(); + Fp beta = distr.beta(); + + tM = a + beta; + tD = beta * beta; + tQ = 9.0 * beta * beta * beta * beta; + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::poisson& distr) { + double tM, tD, tQ; + double lambda = distr.lambda(); + + tM = lambda; + tD = lambda; + tQ = 4 * lambda * lambda + lambda; + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::bernoulli& distr) { + double tM, tD, tQ; + double p = static_cast(distr.p()); + + tM = p; + tD = p * (1.0 - p); + tQ = p * (1.0 - 4.0 * p + 6.0 * p * p - 3.0 * p * p * p); + + return compare_moments(r, tM, tD, tQ); + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::bits& distr) { + return true; + } +}; + +template +struct statistics_device> { + template + bool check(const std::vector& r, + const oneapi::mkl::rng::device::uniform_bits& distr) { + return true; + } +}; + +template +struct is_mcg59 : std::false_type {}; + +template +struct is_mcg59> : std::true_type {}; + +#endif // _RNG_DEVICE_TEST_COMMON_HPP__ diff --git a/tests/unit_tests/rng/device/include/skip_ahead_test.hpp b/tests/unit_tests/rng/device/include/skip_ahead_test.hpp new file mode 100644 index 000000000..0b3bcf8a7 --- /dev/null +++ b/tests/unit_tests/rng/device/include/skip_ahead_test.hpp @@ -0,0 +1,178 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +/* +* +* Content: +* oneapi::mkl::rng::device:: engines skip_ahead and skip_ahead_ex tests +* (SYCL interface) +* +*******************************************************************************/ + +#ifndef _RNG_DEVICE_SKIP_AHEAD_TEST_HPP__ +#define _RNG_DEVICE_SKIP_AHEAD_TEST_HPP__ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl/rng/device.hpp" + +#include "rng_device_test_common.hpp" + +template +class skip_ahead_test { +public: + template + void operator()(Queue queue) { + using UIntType = std::conditional_t::value, std::uint64_t, std::uint32_t>; + + std::vector r(N_GEN); + std::vector r_ref(N_GEN); + + try { + sycl::range<1> range(N_GEN / Engine::vec_size); + + sycl::buffer buf(r); + auto event = queue.submit([&](sycl::handler& cgh) { + sycl::accessor acc(buf, cgh, sycl::write_only); + cgh.parallel_for(range, [=](sycl::item<1> item) { + size_t id = item.get_id(0); + Engine engine(SEED); + oneapi::mkl::rng::device::skip_ahead(engine, id * Engine::vec_size); + oneapi::mkl::rng::device::bits distr; + auto res = oneapi::mkl::rng::device::generate(distr, engine); + if constexpr (Engine::vec_size == 1) { + acc[id] = res; + } + else { + res.store(id, get_multi_ptr(acc)); + } + }); + }); + event.wait_and_throw(); + } + catch (const oneapi::mkl::unimplemented& e) { + status = test_skipped; + return; + } + catch (sycl::exception const& e) { + std::cout << "SYCL exception during generation" << std::endl + << e.what() << std::endl + << "Error code: " << get_error_code(e) << std::endl; + status = test_failed; + return; + } + + // validation + Engine engine(SEED); + oneapi::mkl::rng::device::bits distr; + for (int i = 0; i < N_GEN; i += Engine::vec_size) { + auto res = oneapi::mkl::rng::device::generate(distr, engine); + if constexpr (Engine::vec_size == 1) { + r_ref[i] = res; + } + else { + for (int j = 0; j < Engine::vec_size; ++j) { + r_ref[i + j] = res[j]; + } + } + } + + status = check_equal_vector_device(r, r_ref); + } + + int status = test_passed; +}; + +template +class skip_ahead_ex_test { +public: + template + void operator()(Queue queue) { + std::vector r(N_GEN); + std::vector r_ref(N_GEN); + + try { + sycl::range<1> range(N_GEN / Engine::vec_size); + + sycl::buffer buf(r); + std::uint64_t skip_num = (std::uint64_t)pow(2, 12); + auto event = queue.submit([&](sycl::handler& cgh) { + sycl::accessor acc(buf, cgh, sycl::write_only); + cgh.parallel_for(range, [=](sycl::item<1> item) { + size_t id = item.get_id(0); + Engine engine(SEED); + oneapi::mkl::rng::device::skip_ahead(engine, + { id * Engine::vec_size, skip_num }); + oneapi::mkl::rng::device::bits<> distr; + auto res = oneapi::mkl::rng::device::generate(distr, engine); + if constexpr (Engine::vec_size == 1) { + acc[id] = res; + } + else { + res.store(id, get_multi_ptr(acc)); + } + }); + }); + event.wait_and_throw(); + } + catch (const oneapi::mkl::unimplemented& e) { + status = test_skipped; + return; + } + catch (sycl::exception const& e) { + std::cout << "SYCL exception during generation" << std::endl + << e.what() << std::endl + << "Error code: " << get_error_code(e) << std::endl; + status = test_failed; + return; + } + + // validation + Engine engine(SEED); + for (int j = 0; j < SKIP_TIMES; j++) { + oneapi::mkl::rng::device::skip_ahead(engine, N_SKIP); + } + oneapi::mkl::rng::device::bits<> distr; + for (int i = 0; i < N_GEN; i += Engine::vec_size) { + auto res = oneapi::mkl::rng::device::generate(distr, engine); + if constexpr (Engine::vec_size == 1) { + r_ref[i] = res; + } + else { + for (int j = 0; j < Engine::vec_size; ++j) { + r_ref[i + j] = res[j]; + } + } + } + + status = check_equal_vector_device(r, r_ref); + } + + int status = test_passed; +}; + +#endif // _RNG_DEVICE_SKIP_AHEAD_TEST_HPP__ diff --git a/tests/unit_tests/rng/device/moments/CMakeLists.txt b/tests/unit_tests/rng/device/moments/CMakeLists.txt new file mode 100644 index 000000000..2da8033bf --- /dev/null +++ b/tests/unit_tests/rng/device/moments/CMakeLists.txt @@ -0,0 +1,40 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# Build object from all test sources +set(MOMENTS_DEVICE_TESTS_SOURCES "moments.cpp") + +add_library(rng_device_moments_ct OBJECT ${MOMENTS_DEVICE_TESTS_SOURCES}) +target_compile_options(rng_device_moments_ct PRIVATE -DNOMINMAX) +target_include_directories(rng_device_moments_ct + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin +) +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET rng_device_moments_ct SOURCES ${MOMENTS_DEVICE_TESTS_SOURCES}) +else() + target_link_libraries(rng_device_moments_ct PUBLIC ONEMKL::SYCL::SYCL) +endif() + +if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_link_options(rng_device_moments_ct PUBLIC -fsycl -fsycl-device-code-split=per_kernel) +endif() diff --git a/tests/unit_tests/rng/device/moments/moments.cpp b/tests/unit_tests/rng/device/moments/moments.cpp new file mode 100644 index 000000000..36ce38ee8 --- /dev/null +++ b/tests/unit_tests/rng/device/moments/moments.cpp @@ -0,0 +1,1050 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "moments.hpp" + +#include + +extern std::vector devices; + +namespace { + +class Philox4x32x10UniformStdDeviceMomentsTests : public ::testing::TestWithParam {}; + +class Philox4x32x10UniformAccDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10UniformStdDeviceMomentsTestsSuite, + Philox4x32x10UniformStdDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10UniformAccDeviceMomentsTestsSuite, + Philox4x32x10UniformAccDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Mrg32k3aUniformStdDeviceMomentsTests : public ::testing::TestWithParam {}; + +class Mrg32k3aUniformAccDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Mrg32k3aUniformStdDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformStdDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformStdDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformAccDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aUniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Mrg32k3aUniformStdDeviceMomentsTestsSuite, + Mrg32k3aUniformStdDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Mrg32k3aUniformAccDeviceMomentsTestsSuite, + Mrg32k3aUniformAccDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Mcg31m1UniformStdDeviceMomentsTests : public ::testing::TestWithParam {}; + +class Mcg31m1UniformAccDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Mcg31m1UniformStdDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformStdDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformStdDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformAccDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg31m1UniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Mcg31m1UniformStdDeviceMomentsTestsSuite, + Mcg31m1UniformStdDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Mcg31m1UniformAccDeviceMomentsTestsSuite, + Mcg31m1UniformAccDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Mcg59UniformStdDeviceMomentsTests : public ::testing::TestWithParam {}; + +class Mcg59UniformAccDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Mcg59UniformStdDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformStdDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformStdDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformStdDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::standard>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformAccDeviceMomentsTests, RealSinglePrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + float, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform< + double, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformAccDeviceMomentsTests, IntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::int32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mcg59UniformAccDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::uniform< + std::uint32_t, oneapi::mkl::rng::device::uniform_method::accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Mcg59UniformStdDeviceMomentsTestsSuite, Mcg59UniformStdDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Mcg59UniformAccDeviceMomentsTestsSuite, Mcg59UniformAccDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Philox4x32x10BitsDeviceMomentsTests : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10BitsDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::bits>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bits>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bits>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10BitsDeviceMomentsTestsSuite, + Philox4x32x10BitsDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Philox4x32x10UniformBitsDeviceMomentsTests : public ::testing::TestWithParam { +}; + +TEST_P(Philox4x32x10UniformBitsDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform_bits>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform_bits>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform_bits>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10UniformBitsDeviceMomentsTests, UnsignedLongIntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::uniform_bits>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform_bits>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::uniform_bits>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10UniformBitsDeviceMomentsTestsSuite, + Philox4x32x10UniformBitsDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Philox4x32x10GaussianBoxMuller2DeviceMomentsTests + : public ::testing::TestWithParam {}; + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10GaussianBoxMuller2DeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gaussian< + float, oneapi::mkl::rng::device::gaussian_method::box_muller2>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gaussian< + float, oneapi::mkl::rng::device::gaussian_method::box_muller2>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gaussian< + float, oneapi::mkl::rng::device::gaussian_method::box_muller2>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gaussian< + double, oneapi::mkl::rng::device::gaussian_method::box_muller2>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gaussian< + double, oneapi::mkl::rng::device::gaussian_method::box_muller2>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::gaussian< + double, oneapi::mkl::rng::device::gaussian_method::box_muller2>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10GaussianBoxMuller2DeviceMomentsTestsSuite, + Philox4x32x10GaussianBoxMuller2DeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Philox4x32x10LognormalBoxMuller2DeviceMomentsTests + : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10LognormalBoxMuller2DeviceMomentsTests, RealSinglePrecision) { + rng_device_test< + moments_test, + oneapi::mkl::rng::device::lognormal< + float, oneapi::mkl::rng::device::lognormal_method::box_muller2>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::lognormal< + float, oneapi::mkl::rng::device::lognormal_method::box_muller2>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::lognormal< + float, oneapi::mkl::rng::device::lognormal_method::box_muller2>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10LognormalBoxMuller2DeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::lognormal< + double, oneapi::mkl::rng::device::lognormal_method::box_muller2>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::lognormal< + double, oneapi::mkl::rng::device::lognormal_method::box_muller2>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::lognormal< + double, oneapi::mkl::rng::device::lognormal_method::box_muller2>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10LognormalBoxMuller2DeviceMomentsTestsSuite, + Philox4x32x10LognormalBoxMuller2DeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Philox4x32x10ExponentialIcdfDeviceMomentsTests + : public ::testing::TestWithParam {}; + +class Philox4x32x10ExponentialIcdfAccDeviceMomentsTests + : public ::testing::TestWithParam {}; + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10ExponentialIcdfDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test, + oneapi::mkl::rng::device::exponential< + float, oneapi::mkl::rng::device::exponential_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::exponential< + float, oneapi::mkl::rng::device::exponential_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::exponential< + float, oneapi::mkl::rng::device::exponential_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::exponential< + double, oneapi::mkl::rng::device::exponential_method::icdf>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::exponential< + double, oneapi::mkl::rng::device::exponential_method::icdf>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::exponential< + double, oneapi::mkl::rng::device::exponential_method::icdf>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +// implementation uses double precision for accuracy +TEST_P(Philox4x32x10ExponentialIcdfAccDeviceMomentsTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + + rng_device_test< + moments_test, + oneapi::mkl::rng::device::exponential< + float, oneapi::mkl::rng::device::exponential_method::icdf_accurate>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::exponential< + float, oneapi::mkl::rng::device::exponential_method::icdf_accurate>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::exponential< + float, oneapi::mkl::rng::device::exponential_method::icdf_accurate>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::exponential< + double, oneapi::mkl::rng::device::exponential_method::icdf_accurate>>> + test4; + EXPECT_TRUEORSKIP((test4(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::exponential< + double, oneapi::mkl::rng::device::exponential_method::icdf_accurate>>> + test5; + EXPECT_TRUEORSKIP((test5(GetParam()))); + rng_device_test< + moments_test, + oneapi::mkl::rng::device::exponential< + double, oneapi::mkl::rng::device::exponential_method::icdf_accurate>>> + test6; + EXPECT_TRUEORSKIP((test6(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10ExponentialIcdfDeviceMomentsTestsSuite, + Philox4x32x10ExponentialIcdfDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10ExponentialIcdfAccDeviceMomentsTestsSuite, + Philox4x32x10ExponentialIcdfAccDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Philox4x32x10PoissonDevroyeDeviceMomentsTests + : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, IntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::poisson< + int32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::poisson< + int32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::poisson< + int32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10PoissonDevroyeDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::poisson< + uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::poisson< + uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::poisson< + uint32_t, oneapi::mkl::rng::device::poisson_method::devroye>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10PoissonDevroyeDeviceMomentsTestsSuite, + Philox4x32x10PoissonDevroyeDeviceMomentsTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Philox4x32x10BernoulliIcdfDeviceMomentsTests + : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, IntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + int32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10BernoulliIcdfDeviceMomentsTests, UnsignedIntegerPrecision) { + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test, + oneapi::mkl::rng::device::bernoulli< + uint32_t, oneapi::mkl::rng::device::bernoulli_method::icdf>>> + test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10BernoulliIcdfDeviceMomentsTestsSuite, + Philox4x32x10BernoulliIcdfDeviceMomentsTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/rng/device/service/CMakeLists.txt b/tests/unit_tests/rng/device/service/CMakeLists.txt new file mode 100644 index 000000000..03d960e1a --- /dev/null +++ b/tests/unit_tests/rng/device/service/CMakeLists.txt @@ -0,0 +1,40 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +# Build object from all test sources +set(SERVICE_DEVICE_TESTS_SOURCES "skip_ahead.cpp") + +add_library(rng_device_service_ct OBJECT ${SERVICE_DEVICE_TESTS_SOURCES}) +target_compile_options(rng_device_service_ct PRIVATE -DNOMINMAX) +target_include_directories(rng_device_service_ct + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin +) +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET rng_device_service_ct SOURCES ${SERVICE_DEVICE_TESTS_SOURCES}) +else() + target_link_libraries(rng_device_service_ct PUBLIC ONEMKL::SYCL::SYCL) +endif() + +if(NOT ${ONEMKL_SYCL_IMPLEMENTATION} STREQUAL "hipsycl") + target_link_options(rng_device_service_ct PUBLIC -fsycl -fsycl-device-code-split=per_kernel) +endif() diff --git a/tests/unit_tests/rng/device/service/skip_ahead.cpp b/tests/unit_tests/rng/device/service/skip_ahead.cpp new file mode 100644 index 000000000..a5dfe0da8 --- /dev/null +++ b/tests/unit_tests/rng/device/service/skip_ahead.cpp @@ -0,0 +1,113 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include "skip_ahead_test.hpp" + +#include + +extern std::vector devices; + +namespace { + +class Philox4x32x10DeviceSkipAheadTests : public ::testing::TestWithParam {}; + +class Philox4x32x10DeviceSkipAheadExTests : public ::testing::TestWithParam {}; + +TEST_P(Philox4x32x10DeviceSkipAheadTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Philox4x32x10DeviceSkipAheadExTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10DeviceSkipAheadTestsSuite, Philox4x32x10DeviceSkipAheadTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Philox4x32x10DeviceSkipAheadExTestsSuite, + Philox4x32x10DeviceSkipAheadExTests, ::testing::ValuesIn(devices), + ::DeviceNamePrint()); + +class Mrg32k3aDeviceSkipAheadTests : public ::testing::TestWithParam {}; + +class Mrg32k3aDeviceSkipAheadExTests : public ::testing::TestWithParam {}; + +TEST_P(Mrg32k3aDeviceSkipAheadTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +TEST_P(Mrg32k3aDeviceSkipAheadExTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Mrg32k3aDeviceSkipAheadTestsSuite, Mrg32k3aDeviceSkipAheadTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +INSTANTIATE_TEST_SUITE_P(Mrg32k3aDeviceSkipAheadExTestsSuite, Mrg32k3aDeviceSkipAheadExTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Mcg31m1DeviceSkipAheadTests : public ::testing::TestWithParam {}; + +TEST_P(Mcg31m1DeviceSkipAheadTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Mcg31m1DeviceSkipAheadTestsSuite, Mcg31m1DeviceSkipAheadTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +class Mcg59DeviceSkipAheadTests : public ::testing::TestWithParam {}; + +TEST_P(Mcg59DeviceSkipAheadTests, BinaryPrecision) { + rng_device_test>> test1; + EXPECT_TRUEORSKIP((test1(GetParam()))); + rng_device_test>> test2; + EXPECT_TRUEORSKIP((test2(GetParam()))); + rng_device_test>> test3; + EXPECT_TRUEORSKIP((test3(GetParam()))); +} + +INSTANTIATE_TEST_SUITE_P(Mcg59DeviceSkipAheadTestsSuite, Mcg59DeviceSkipAheadTests, + ::testing::ValuesIn(devices), ::DeviceNamePrint()); + +} // namespace diff --git a/tests/unit_tests/rng/include/engines_api_tests.hpp b/tests/unit_tests/rng/include/engines_api_tests.hpp index 8b7004f70..2469c3023 100644 --- a/tests/unit_tests/rng/include/engines_api_tests.hpp +++ b/tests/unit_tests/rng/include/engines_api_tests.hpp @@ -63,6 +63,7 @@ class engines_constructors_test { oneapi::mkl::rng::generate(distr, engine2, N_GEN, r2_buffer); oneapi::mkl::rng::generate(distr, engine3, N_GEN, r3_buffer); oneapi::mkl::rng::generate(distr, engine4, N_GEN, r4_buffer); + QUEUE_WAIT(queue); } catch (const oneapi::mkl::unimplemented& e) { status = test_skipped; @@ -118,6 +119,7 @@ class engines_copy_test { oneapi::mkl::rng::generate(distr, engine3, N_GEN, r2_buffer); oneapi::mkl::rng::generate(distr, engine4, N_GEN, r3_buffer); } + QUEUE_WAIT(queue); } catch (const oneapi::mkl::unimplemented& e) { status = test_skipped; diff --git a/tests/unit_tests/rng/include/rng_test_common.hpp b/tests/unit_tests/rng/include/rng_test_common.hpp index 171ef7228..d01b04cce 100644 --- a/tests/unit_tests/rng/include/rng_test_common.hpp +++ b/tests/unit_tests/rng/include/rng_test_common.hpp @@ -112,7 +112,7 @@ class rng_test { #ifdef CALL_RT_API test_(queue, args...); #else - TEST_RUN_CT_SELECT(queue, test_, args...); + TEST_RUN_RNG_CT_SELECT(queue, test_, args...); #endif return test_.status; @@ -122,4 +122,10 @@ class rng_test { Test test_; }; +#ifdef CALL_RT_API +#define QUEUE_WAIT(q) q.wait() +#else +#define QUEUE_WAIT(q) q.get_queue().wait() +#endif + #endif // _RNG_TEST_COMMON_HPP__ diff --git a/tests/unit_tests/rng/include/skip_ahead_test.hpp b/tests/unit_tests/rng/include/skip_ahead_test.hpp index b0f827a13..efec71dde 100644 --- a/tests/unit_tests/rng/include/skip_ahead_test.hpp +++ b/tests/unit_tests/rng/include/skip_ahead_test.hpp @@ -67,6 +67,7 @@ class skip_ahead_test { for (int i = 0; i < N_ENGINES; i++) { oneapi::mkl::rng::generate(distr, *(engines[i]), N_PORTION, r_buffers[i]); } + QUEUE_WAIT(queue); // Clear memory for (int i = 0; i < N_ENGINES; i++) { @@ -118,6 +119,7 @@ class skip_ahead_ex_test { oneapi::mkl::rng::generate(distr, engine1, N_GEN, r1_buffer); oneapi::mkl::rng::generate(distr, engine2, N_GEN, r2_buffer); + QUEUE_WAIT(queue); } catch (const oneapi::mkl::unimplemented& e) { status = test_skipped; diff --git a/tests/unit_tests/rng/include/statistics_check_test.hpp b/tests/unit_tests/rng/include/statistics_check_test.hpp index 52f26f6f0..14a637d7a 100644 --- a/tests/unit_tests/rng/include/statistics_check_test.hpp +++ b/tests/unit_tests/rng/include/statistics_check_test.hpp @@ -63,6 +63,7 @@ class statistics_test { Engine engine(queue, SEED); Distr distr(args...); oneapi::mkl::rng::generate(distr, engine, n_gen, r_buffer); + QUEUE_WAIT(queue); } catch (sycl::exception const& e) { std::cout << "Caught synchronous SYCL exception during generation:\n" diff --git a/tests/unit_tests/rng/statistics_check/gaussian.cpp b/tests/unit_tests/rng/statistics_check/gaussian.cpp index 1da91b595..ed63f3221 100644 --- a/tests/unit_tests/rng/statistics_check/gaussian.cpp +++ b/tests/unit_tests/rng/statistics_check/gaussian.cpp @@ -43,6 +43,8 @@ TEST_P(GaussianIcdfTest, RealSinglePrecision) { } TEST_P(GaussianIcdfTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test< statistics_test, oneapi::mkl::rng::philox4x32x10>> @@ -69,6 +71,8 @@ TEST_P(GaussianBoxmullerTest, RealSinglePrecision) { } TEST_P(GaussianBoxmullerTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> diff --git a/tests/unit_tests/rng/statistics_check/gaussian_usm.cpp b/tests/unit_tests/rng/statistics_check/gaussian_usm.cpp index d95ee6fea..a1d4d1b06 100644 --- a/tests/unit_tests/rng/statistics_check/gaussian_usm.cpp +++ b/tests/unit_tests/rng/statistics_check/gaussian_usm.cpp @@ -43,6 +43,8 @@ TEST_P(GaussianIcdfUsmTest, RealSinglePrecision) { } TEST_P(GaussianIcdfUsmTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> @@ -69,6 +71,8 @@ TEST_P(GaussianBoxmullerUsmTest, RealSinglePrecision) { } TEST_P(GaussianBoxmullerUsmTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> diff --git a/tests/unit_tests/rng/statistics_check/lognormal.cpp b/tests/unit_tests/rng/statistics_check/lognormal.cpp index 456eb6e05..5486202bb 100755 --- a/tests/unit_tests/rng/statistics_check/lognormal.cpp +++ b/tests/unit_tests/rng/statistics_check/lognormal.cpp @@ -43,6 +43,8 @@ TEST_P(LognormalIcdfTest, RealSinglePrecision) { } TEST_P(LognormalIcdfTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> @@ -69,6 +71,8 @@ TEST_P(LognormalBoxmullerTest, RealSinglePrecision) { } TEST_P(LognormalBoxmullerTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> diff --git a/tests/unit_tests/rng/statistics_check/lognormal_usm.cpp b/tests/unit_tests/rng/statistics_check/lognormal_usm.cpp index 1e648a6bb..d59d9458a 100755 --- a/tests/unit_tests/rng/statistics_check/lognormal_usm.cpp +++ b/tests/unit_tests/rng/statistics_check/lognormal_usm.cpp @@ -43,6 +43,8 @@ TEST_P(LognormalIcdfUsmTest, RealSinglePrecision) { } TEST_P(LognormalIcdfUsmTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> @@ -69,6 +71,8 @@ TEST_P(LognormalBoxmullerUsmTest, RealSinglePrecision) { } TEST_P(LognormalBoxmullerUsmTest, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> diff --git a/tests/unit_tests/rng/statistics_check/uniform.cpp b/tests/unit_tests/rng/statistics_check/uniform.cpp index 3e4f16878..eb11714e1 100644 --- a/tests/unit_tests/rng/statistics_check/uniform.cpp +++ b/tests/unit_tests/rng/statistics_check/uniform.cpp @@ -43,6 +43,8 @@ TEST_P(UniformStdTests, RealSinglePrecision) { } TEST_P(UniformStdTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> @@ -82,6 +84,8 @@ TEST_P(UniformAccurateTests, RealSinglePrecision) { } TEST_P(UniformAccurateTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> diff --git a/tests/unit_tests/rng/statistics_check/uniform_usm.cpp b/tests/unit_tests/rng/statistics_check/uniform_usm.cpp index 82eac9fcc..df4f7a764 100644 --- a/tests/unit_tests/rng/statistics_check/uniform_usm.cpp +++ b/tests/unit_tests/rng/statistics_check/uniform_usm.cpp @@ -43,6 +43,8 @@ TEST_P(UniformStdUsmTests, RealSinglePrecision) { } TEST_P(UniformStdUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> @@ -82,6 +84,8 @@ TEST_P(UniformAccurateUsmTests, RealSinglePrecision) { } TEST_P(UniformAccurateUsmTests, RealDoublePrecision) { + CHECK_DOUBLE_ON_DEVICE(GetParam()); + rng_test, oneapi::mkl::rng::philox4x32x10>> diff --git a/tests/unit_tests/sparse_blas/CMakeLists.txt b/tests/unit_tests/sparse_blas/CMakeLists.txt new file mode 100644 index 000000000..2c46cd38c --- /dev/null +++ b/tests/unit_tests/sparse_blas/CMakeLists.txt @@ -0,0 +1,20 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +add_subdirectory(source) diff --git a/tests/unit_tests/sparse_blas/include/sparse_reference.hpp b/tests/unit_tests/sparse_blas/include/sparse_reference.hpp new file mode 100644 index 000000000..ffb876f11 --- /dev/null +++ b/tests/unit_tests/sparse_blas/include/sparse_reference.hpp @@ -0,0 +1,297 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _SPARSE_REFERENCE_HPP__ +#define _SPARSE_REFERENCE_HPP__ + +#include +#include +#include + +#include "oneapi/mkl.hpp" + +#include "test_common.hpp" + +template +inline T conjugate(T) { + static_assert(false, "Unsupported type"); +} +template <> +inline float conjugate(float t) { + return t; +} +template <> +inline double conjugate(double t) { + return t; +} +template <> +inline std::complex conjugate(std::complex t) { + return std::conj(t); +} +template <> +inline std::complex conjugate(std::complex t) { + return std::conj(t); +} + +template +inline T opVal(const T t, const bool isConj) { + return (isConj ? conjugate(t) : t); +}; + +template +void do_csr_transpose(const oneapi::mkl::transpose opA, intType *ia_t, intType *ja_t, fpType *a_t, + intType a_nrows, intType a_ncols, intType a_ind, accIntType &ia, + accIntType &ja, accFpType &a, const bool structOnlyFlag = false) { + const bool isConj = (opA == oneapi::mkl::transpose::conjtrans); + + // initialize ia_t to zero + for (intType i = 0; i < a_ncols + 1; ++i) { + ia_t[i] = 0; + } + + // fill ia_t with counts of columns + for (intType i = 0; i < a_nrows; ++i) { + const intType st = ia[i] - a_ind; + const intType en = ia[i + 1] - a_ind; + for (intType j = st; j < en; ++j) { + const intType col = ja[j] - a_ind; + ia_t[col + 1]++; + } + } + // prefix sum to get official ia_t counts + ia_t[0] = a_ind; + for (intType i = 0; i < a_ncols; ++i) { + ia_t[i + 1] += ia_t[i]; + } + + // second pass through data to fill transpose structure + for (intType i = 0; i < a_nrows; ++i) { + const intType st = ia[i] - a_ind; + const intType en = ia[i + 1] - a_ind; + for (intType j = st; j < en; ++j) { + const intType col = ja[j] - a_ind; + const intType j_in_a_t = ia_t[col] - a_ind; + ia_t[col]++; + ja_t[j_in_a_t] = i + a_ind; + if (!structOnlyFlag) { + const fpType val = a[j]; + a_t[j_in_a_t] = opVal(val, isConj); + } + } + } + + // adjust ia_t back to original state after filling structure + for (intType i = a_ncols; i > 0; --i) { + ia_t[i] = ia_t[i - 1]; + } + ia_t[0] = a_ind; +} + +// Transpose the given sparse matrix if needed +template +auto sparse_transpose_if_needed(const intType *ia, const intType *ja, const fpType *a, + intType a_nrows, intType a_ncols, std::size_t nnz, intType a_ind, + oneapi::mkl::transpose transpose_val) { + std::vector iopa; + std::vector jopa; + std::vector opa; + if (transpose_val == oneapi::mkl::transpose::nontrans) { + iopa.assign(ia, ia + a_nrows + 1); + jopa.assign(ja, ja + nnz); + opa.assign(a, a + nnz); + } + else if (transpose_val == oneapi::mkl::transpose::trans || + transpose_val == oneapi::mkl::transpose::conjtrans) { + iopa.resize(static_cast(a_ncols + 1)); + jopa.resize(nnz); + opa.resize(nnz); + do_csr_transpose(transpose_val, iopa.data(), jopa.data(), opa.data(), a_nrows, a_ncols, + a_ind, ia, ja, a); + } + else { + throw std::runtime_error("unsupported transpose_val=" + + std::to_string(static_cast(transpose_val))); + } + return std::make_tuple(iopa, jopa, opa); +} + +template +auto dense_transpose_if_needed(const fpType *x, std::size_t outer_size, std::size_t inner_size, + std::size_t ld, oneapi::mkl::transpose transpose_val) { + std::vector opx; + if (transpose_val == oneapi::mkl::transpose::nontrans) { + opx.assign(x, x + outer_size * ld); + } + else { + opx.resize(outer_size * ld); + for (std::size_t i = 0; i < outer_size; ++i) { + for (std::size_t j = 0; j < inner_size; ++j) { + opx[i + j * ld] = x[i * ld + j]; + } + } + } + return opx; +} + +/// Return the dense matrix A in row major layout. +/// Diagonal values are overwritten with 1s if diag_val is unit. +template +std::vector sparse_to_dense(const intType *ia, const intType *ja, const fpType *a, + std::size_t a_nrows, std::size_t a_ncols, intType a_ind, + oneapi::mkl::transpose transpose_val, + oneapi::mkl::diag diag_val) { + std::vector dense_a(a_nrows * a_ncols, fpType(0)); + for (std::size_t row = 0; row < a_nrows; row++) { + for (intType i = ia[row] - a_ind; i < ia[row + 1] - a_ind; i++) { + std::size_t iu = static_cast(i); + std::size_t col = static_cast(ja[iu] - a_ind); + std::size_t dense_a_idx = transpose_val != oneapi::mkl::transpose::nontrans + ? col * a_nrows + row + : row * a_ncols + col; + fpType val = a[iu]; + if constexpr (complex_info::is_complex) { + if (transpose_val == oneapi::mkl::transpose::conjtrans) { + val = std::conj(val); + } + } + dense_a[dense_a_idx] = val; + } + } + if (diag_val == oneapi::mkl::diag::unit) { + for (std::size_t i = 0; i < a_nrows; ++i) { + dense_a[i * a_ncols + i] = set_fp_value()(1.f, 0.f); + } + } + return dense_a; +} + +template +void prepare_reference_gemv_data(const intType *ia, const intType *ja, const fpType *a, + intType a_nrows, intType a_ncols, intType a_nnz, intType a_ind, + oneapi::mkl::transpose opA, fpType alpha, fpType beta, + const fpType *x, fpType *y_ref) { + std::size_t opa_nrows = + static_cast((opA == oneapi::mkl::transpose::nontrans) ? a_nrows : a_ncols); + const std::size_t nnz = static_cast(a_nnz); + auto [iopa, jopa, opa] = + sparse_transpose_if_needed(ia, ja, a, a_nrows, a_ncols, nnz, a_ind, opA); + + // + // do GEMV operation + // + // y_ref <- alpha * op(A) * x + beta * y_ref + // + for (std::size_t row = 0; row < opa_nrows; row++) { + fpType tmp = 0; + for (intType i = iopa[row] - a_ind; i < iopa[row + 1] - a_ind; i++) { + std::size_t iu = static_cast(i); + std::size_t x_ind = static_cast(jopa[iu] - a_ind); + tmp += opa[iu] * x[x_ind]; + } + + y_ref[row] = alpha * tmp + beta * y_ref[row]; + } +} + +template +void prepare_reference_gemm_data(const intType *ia, const intType *ja, const fpType *a, + intType a_nrows, intType a_ncols, intType c_ncols, intType a_nnz, + intType a_ind, oneapi::mkl::layout dense_matrix_layout, + oneapi::mkl::transpose opA, oneapi::mkl::transpose opB, + fpType alpha, fpType beta, intType ldb, intType ldc, + const fpType *b, fpType *c_ref) { + std::size_t opa_nrows = + static_cast((opA == oneapi::mkl::transpose::nontrans) ? a_nrows : a_ncols); + std::size_t opa_ncols = + static_cast((opA == oneapi::mkl::transpose::nontrans) ? a_ncols : a_nrows); + const std::size_t nnz = static_cast(a_nnz); + const std::size_t ldb_u = static_cast(ldb); + const std::size_t ldc_u = static_cast(ldc); + auto [iopa, jopa, opa] = + sparse_transpose_if_needed(ia, ja, a, a_nrows, a_ncols, nnz, a_ind, opA); + + std::size_t b_outer_size = static_cast(opa_ncols); + std::size_t b_inner_size = static_cast(c_ncols); + if (dense_matrix_layout == oneapi::mkl::layout::col_major) { + std::swap(b_outer_size, b_inner_size); + } + auto opb = dense_transpose_if_needed(b, b_outer_size, b_inner_size, ldb_u, opB); + + // + // do GEMM operation + // + // C <- alpha * opA(A) * opB(B) + beta * C + // + if (dense_matrix_layout == oneapi::mkl::layout::row_major) { + for (std::size_t row = 0; row < opa_nrows; row++) { + for (std::size_t col = 0; col < static_cast(c_ncols); col++) { + fpType tmp = 0; + for (std::size_t i = static_cast(iopa[row] - a_ind); + i < static_cast(iopa[row + 1] - a_ind); i++) { + tmp += opa[i] * opb[static_cast(jopa[i] - a_ind) * ldb_u + col]; + } + fpType &c = c_ref[row * ldc_u + col]; + c = alpha * tmp + beta * c; + } + } + } + else { + for (std::size_t col = 0; col < static_cast(c_ncols); col++) { + for (std::size_t row = 0; row < opa_nrows; row++) { + fpType tmp = 0; + for (std::size_t i = static_cast(iopa[row] - a_ind); + i < static_cast(iopa[row + 1] - a_ind); i++) { + tmp += opa[i] * opb[static_cast(jopa[i] - a_ind) + col * ldb_u]; + } + fpType &c = c_ref[row + col * ldc_u]; + c = alpha * tmp + beta * c; + } + } + } +} + +template +void prepare_reference_trsv_data(const intType *ia, const intType *ja, const fpType *a, intType m, + intType a_ind, oneapi::mkl::uplo uplo_val, + oneapi::mkl::transpose opA, oneapi::mkl::diag diag_val, + const fpType *x, fpType *y_ref) { + std::size_t mu = static_cast(m); + auto dense_a = sparse_to_dense(ia, ja, a, mu, mu, a_ind, opA, diag_val); + + // + // do TRSV operation + // + // y_ref <- op(A)^-1 * x + // + // Compute each element of the reference one after the other starting from 0 (resp. the end) for a lower (resp. upper) triangular matrix. + // A matrix is considered lowered if it is lower and not transposed or upper and transposed. + const bool is_lower = + (uplo_val == oneapi::mkl::uplo::lower) == (opA == oneapi::mkl::transpose::nontrans); + for (std::size_t row = 0; row < mu; row++) { + std::size_t uplo_row = is_lower ? row : (mu - 1 - row); + fpType rhs = x[uplo_row]; + for (std::size_t col = 0; col < row; col++) { + std::size_t uplo_col = is_lower ? col : (mu - 1 - col); + rhs -= dense_a[uplo_row * mu + uplo_col] * y_ref[uplo_col]; + } + y_ref[uplo_row] = rhs / dense_a[uplo_row * mu + uplo_row]; + } +} + +#endif // _SPARSE_REFERENCE_HPP__ diff --git a/tests/unit_tests/sparse_blas/include/test_common.hpp b/tests/unit_tests/sparse_blas/include/test_common.hpp new file mode 100644 index 000000000..fd1e91a47 --- /dev/null +++ b/tests/unit_tests/sparse_blas/include/test_common.hpp @@ -0,0 +1,286 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#ifndef _TEST_COMMON_HPP__ +#define _TEST_COMMON_HPP__ + +#include +#include +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "test_helper.hpp" + +// Sparse BLAS domain needs to call more functions per test so we use this macro helper to select between runtime and compile dispatch for each function +#ifdef CALL_RT_API +#define CALL_RT_OR_CT(FUNC, QUEUE, ...) FUNC(QUEUE, __VA_ARGS__) +#else +#define CALL_RT_OR_CT(FUNC, QUEUE, ...) TEST_RUN_CT_SELECT(QUEUE, FUNC, __VA_ARGS__); +#endif + +template +struct complex_info { + using real_type = T; + static const bool is_complex = false; +}; + +template +struct complex_info> { + using real_type = T; + static const bool is_complex = true; +}; + +void print_error_code(sycl::exception const &e); + +// Catch asynchronous exceptions. +struct exception_handler_t { + void operator()(sycl::exception_list exceptions) { + for (std::exception_ptr const &e : exceptions) { + try { + std::rethrow_exception(e); + } + catch (sycl::exception const &e) { + std::cout << "Caught asynchronous SYCL exception:\n" << e.what() << std::endl; + print_error_code(e); + } + } + } +}; + +// Use a unique_ptr to automatically free device memory on unique_ptr destruction. +template +auto malloc_device_uptr(sycl::queue q, std::size_t num_elts) { + struct Deleter { + sycl::queue q; + Deleter(sycl::queue _q) : q(_q) {} + void operator()(T *ptr) { + sycl::free(ptr, q); + } + }; + return std::unique_ptr(sycl::malloc_device(num_elts, q), Deleter(q)); +} + +// SYCL buffer creation helper. +template +sycl::buffer make_buffer(const vec &v) { + sycl::buffer buf(v.data(), sycl::range<1>(v.size())); + return buf; +} + +template +struct set_fp_value { + inline fpType operator()(fpType real, fpType /*imag*/) { + return real; + } +}; + +template +struct set_fp_value> { + inline auto operator()(scalarType real, scalarType imag) { + return std::complex(real, imag); + } +}; + +template +struct rand_scalar { + inline fpType operator()(double min, double max) { + return (fpType(std::rand()) / fpType(RAND_MAX)) * fpType(max - min) + fpType(min); + } +}; + +template +struct rand_scalar> { + inline std::complex operator()(double min, double max) { + rand_scalar rand; + return std::complex(rand(min, max), rand(min, max)); + } +}; + +template +void rand_vector(std::vector &v, std::size_t n) { + using fpRealType = typename complex_info::real_type; + v.resize(n); + rand_scalar rand; + for (std::size_t i = 0; i < n; i++) { + v[i] = rand(fpRealType(-0.5), fpRealType(0.5)); + } +} + +template +void rand_matrix(std::vector &m, oneapi::mkl::layout layout_val, std::size_t nrows, + std::size_t ncols, std::size_t ld) { + using fpRealType = typename complex_info::real_type; + std::size_t outer_size = nrows; + std::size_t inner_size = ncols; + if (layout_val == oneapi::mkl::layout::col_major) { + std::swap(outer_size, inner_size); + } + m.resize(outer_size * ld); + rand_scalar rand; + for (std::size_t i = 0; i < outer_size; ++i) { + std::size_t j = 0; + for (; j < inner_size; ++j) { + m[i * ld + j] = rand(fpRealType(-0.5), fpRealType(0.5)); + } + for (; j < ld; ++j) { + m[i * ld + j] = set_fp_value()(-1.f, 0.f); + } + } +} + +// Creating the 3arrays CSR representation (ia, ja, values) +// of general random sparse matrix +// with density (0 < density <= 1.0) +// -0.5 <= value < 0.5 +// require_diagonal means all diagonal entries guaranteed to be nonzero +template +intType generate_random_matrix(const intType nrows, const intType ncols, const double density_val, + intType indexing, std::vector &ia, std::vector &ja, + std::vector &a, bool require_diagonal = false) { + intType nnz = 0; + rand_scalar rand_density; + rand_scalar rand_data; + + ia.push_back(indexing); // starting index of row0. + for (intType i = 0; i < nrows; i++) { + ia.push_back(nnz + indexing); // ending index of row_i. + for (intType j = 0; j < ncols; j++) { + const bool is_diag = require_diagonal && i == j; + if (is_diag || (rand_density(0.0, 1.0) <= density_val)) { + fpType val; + if (is_diag) { + // Guarantee an amplitude >= 0.1 + fpType sign = (std::rand() % 2) * 2 - 1; + val = rand_data(0.1, 0.5) * sign; + } + else { + val = rand_data(-0.5, 0.5); + } + a.push_back(val); + ja.push_back(j + indexing); + nnz++; + } + } + ia[static_cast(i) + 1] = nnz + indexing; + } + return nnz; +} + +// Shuffle the 3arrays CSR representation (ia, ja, values) +// of any sparse matrix and set values serially from 0..nnz. +// Intended for use with sorting. +template +void shuffle_data(const intType *ia, intType *ja, fpType *a, const std::size_t nrows) { + // + // shuffle indices according to random seed + // + intType indexing = ia[0]; + for (std::size_t i = 0; i < nrows; ++i) { + intType nnz_row = ia[i + 1] - ia[i]; + for (intType j = ia[i] - indexing; j < ia[i + 1] - indexing; ++j) { + intType q = ia[i] - indexing + std::rand() % (nnz_row); + // swap element i and q + std::swap(ja[q], ja[j]); + std::swap(a[q], a[j]); + } + } +} + +inline void wait_and_free(sycl::queue &main_queue, oneapi::mkl::sparse::matrix_handle_t *p_handle) { + main_queue.wait(); + sycl::event ev_release; + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, p_handle); + ev_release.wait(); +} + +template +bool check_equal(fpType x, fpType x_ref, double abs_error_margin, double rel_error_margin, + std::ostream &out) { + using fpRealType = typename complex_info::real_type; + static_assert(std::is_floating_point_v, + "Expected floating-point real or complex type."); + + const fpRealType epsilon = std::numeric_limits::epsilon(); + const auto abs_bound = static_cast(abs_error_margin) * epsilon; + const auto rel_bound = static_cast(rel_error_margin) * epsilon; + + const auto aerr = std::abs(x - x_ref); + const auto rerr = aerr / std::abs(x_ref); + const bool valid = (rerr <= rel_bound) || (aerr <= abs_bound); + if (!valid) { + out << "Mismatching results: actual = " << x << " vs. reference = " << x_ref << "\n"; + out << " relative error = " << rerr << " absolute error = " << aerr + << " relative bound = " << rel_bound << " absolute bound = " << abs_bound << "\n"; + } + return valid; +} + +template +bool check_equal_vector(const vecType1 &v, const vecType2 &v_ref, double abs_error_factor = 10.0, + double rel_error_factor = 200.0, std::ostream &out = std::cout) { + using T = typename vecType2::value_type; + std::size_t n = v.size(); + if (n != v_ref.size()) { + out << "Mismatching size got " << n << " expected " << v_ref.size() << "\n"; + return false; + } + if (n == 0) { + return true; + } + + auto max_norm_ref = + *std::max_element(std::begin(v_ref), std::end(v_ref), + [](const T &a, const T &b) { return std::abs(a) < std::abs(b); }); + // Heuristic for the average-case error margins + double abs_error_margin = + abs_error_factor * std::abs(max_norm_ref) * std::log2(static_cast(n)); + double rel_error_margin = rel_error_factor * std::log2(static_cast(n)); + + constexpr int max_print = 20; + int count = 0; + bool valid = true; + + for (std::size_t i = 0; i < n; ++i) { + // Allow to convert the unsigned index `i` to a signed one to keep this function generic and allow for `v` and `v_ref` to be a vector, a pointer or a random access iterator. +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wsign-conversion" + auto res = v[i]; + auto ref = v_ref[i]; +#pragma clang diagnostic pop + if (!check_equal(res, ref, abs_error_margin, rel_error_margin, out)) { + out << " at index i =" << i << "\n"; + valid = false; + ++count; + if (count > max_print) { + return valid; + } + } + } + + return valid; +} + +#endif // _TEST_COMMON_HPP__ diff --git a/tests/unit_tests/sparse_blas/source/CMakeLists.txt b/tests/unit_tests/sparse_blas/source/CMakeLists.txt new file mode 100644 index 000000000..3a1fcb288 --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/CMakeLists.txt @@ -0,0 +1,63 @@ +#=============================================================================== +# Copyright 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +# +# +# SPDX-License-Identifier: Apache-2.0 +#=============================================================================== + +set(SPBLAS_SOURCES + "sparse_gemm_buffer.cpp" + "sparse_gemm_usm.cpp" + "sparse_gemv_buffer.cpp" + "sparse_gemv_usm.cpp" + "sparse_trsv_buffer.cpp" + "sparse_trsv_usm.cpp" +) + +include(WarningsUtils) + +if (BUILD_SHARED_LIBS) + add_library(spblas_source_rt OBJECT ${SPBLAS_SOURCES}) + target_compile_options(spblas_source_rt PRIVATE -DCALL_RT_API -DNOMINMAX) + target_include_directories(spblas_source_rt + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) + if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET spblas_source_rt SOURCES ${SPBLAS_SOURCES}) + else () + target_link_libraries(spblas_source_rt PUBLIC ONEMKL::SYCL::SYCL) + endif () + target_link_libraries(spblas_source_rt PRIVATE onemkl_warnings) +endif () + +add_library(spblas_source_ct OBJECT ${SPBLAS_SOURCES}) +target_compile_options(spblas_source_ct PRIVATE -DNOMINMAX) +target_include_directories(spblas_source_ct + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../include + PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/../../include + PUBLIC ${PROJECT_SOURCE_DIR}/include + PUBLIC ${PROJECT_SOURCE_DIR}/deps/googletest/include + PUBLIC ${CMAKE_BINARY_DIR}/bin + ) +if (USE_ADD_SYCL_TO_TARGET_INTEGRATION) + add_sycl_to_target(TARGET spblas_source_ct SOURCES ${SPBLAS_SOURCES}) +else () + target_link_libraries(spblas_source_ct PUBLIC ONEMKL::SYCL::SYCL) +endif () +target_link_libraries(spblas_source_ct PRIVATE onemkl_warnings) diff --git a/tests/unit_tests/sparse_blas/source/sparse_gemm_buffer.cpp b/tests/unit_tests/sparse_blas/source/sparse_gemm_buffer.cpp new file mode 100644 index 000000000..1c9549fcc --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/sparse_gemm_buffer.cpp @@ -0,0 +1,302 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl.hpp" +#include "oneapi/mkl/detail/config.hpp" +#include "sparse_reference.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +extern std::vector devices; + +namespace { + +template +int test(sycl::device *dev, intType nrows_A, intType ncols_A, intType ncols_C, + double density_A_matrix, oneapi::mkl::index_base index, + oneapi::mkl::layout dense_matrix_layout, oneapi::mkl::transpose transpose_A, + oneapi::mkl::transpose transpose_B, fpType alpha, fpType beta, intType ldb, intType ldc, + bool opt_1_input, bool opt_2_inputs) { + sycl::queue main_queue(*dev, exception_handler_t()); + + intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; + std::size_t opa_nrows = static_cast( + transpose_A == oneapi::mkl::transpose::nontrans ? nrows_A : ncols_A); + std::size_t opa_ncols = static_cast( + transpose_A == oneapi::mkl::transpose::nontrans ? ncols_A : nrows_A); + + // Input matrix in CSR format + std::vector ia_host, ja_host; + std::vector a_host; + intType nnz = generate_random_matrix(nrows_A, ncols_A, density_A_matrix, + int_index, ia_host, ja_host, a_host); + + // Input and output dense vectors + std::vector b_host, c_host; + rand_matrix(b_host, dense_matrix_layout, opa_ncols, static_cast(ncols_C), + static_cast(ldb)); + rand_matrix(c_host, dense_matrix_layout, opa_nrows, static_cast(ncols_C), + static_cast(ldc)); + std::vector c_ref_host(c_host); + + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), static_cast(nrows_A)); + + auto ia_buf = make_buffer(ia_host); + auto ja_buf = make_buffer(ja_host); + auto a_buf = make_buffer(a_host); + auto b_buf = make_buffer(b_host); + auto c_buf = make_buffer(c_host); + + sycl::event ev_release; + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + try { + CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); + + CALL_RT_OR_CT(oneapi::mkl::sparse::set_csr_data, main_queue, handle, nrows_A, ncols_A, nnz, + index, ia_buf, ja_buf, a_buf); + + if (opt_1_input) { + CALL_RT_OR_CT(oneapi::mkl::sparse::optimize_gemm, main_queue, transpose_A, handle); + } + + if (opt_2_inputs) { + CALL_RT_OR_CT(oneapi::mkl::sparse::optimize_gemm, main_queue, transpose_A, transpose_B, + dense_matrix_layout, static_cast(ncols_C), handle); + } + + CALL_RT_OR_CT(oneapi::mkl::sparse::gemm, main_queue, dense_matrix_layout, transpose_A, + transpose_B, alpha, handle, b_buf, ncols_C, ldb, beta, c_buf, ldc); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle); + } + catch (const sycl::exception &e) { + std::cout << "Caught synchronous SYCL exception during sparse GEMV:\n" + << e.what() << std::endl; + print_error_code(e); + return 0; + } + catch (const oneapi::mkl::unimplemented &e) { + wait_and_free(main_queue, &handle); + return test_skipped; + } + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of sparse GEMV:\n" << error.what() << std::endl; + return 0; + } + + // Compute reference. + prepare_reference_gemm_data(ia_host.data(), ja_host.data(), a_host.data(), nrows_A, ncols_A, + ncols_C, nnz, int_index, dense_matrix_layout, transpose_A, + transpose_B, alpha, beta, ldb, ldc, b_host.data(), + c_ref_host.data()); + + // Compare the results of reference implementation and DPC++ implementation. + auto c_acc = c_buf.get_host_access(sycl::read_only); + bool valid = check_equal_vector(c_acc, c_ref_host); + + ev_release.wait_and_throw(); + return static_cast(valid); +} + +class SparseGemmBufferTests : public ::testing::TestWithParam {}; + +/** + * Helper function to run tests in different configuration. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param transpose_A Transpose value for the A matrix + * @param transpose_B Transpose value for the B matrix + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +void test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_A, + oneapi::mkl::transpose transpose_B, int &num_passed, int &num_skipped) { + double density_A_matrix = 0.8; + fpType fp_zero = set_fp_value()(0.f, 0.f); + fpType fp_one = set_fp_value()(1.f, 0.f); + oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; + oneapi::mkl::layout col_major = oneapi::mkl::layout::col_major; + int nrows_A = 4, ncols_A = 6, ncols_C = 5; + int ldb = transpose_A == oneapi::mkl::transpose::nontrans ? ncols_A : nrows_A; + int ldc = transpose_A == oneapi::mkl::transpose::nontrans ? nrows_A : ncols_A; + bool no_opt_1_input = false; + bool opt_2_inputs = true; + + // Basic test + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test index_base 1 + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, oneapi::mkl::index_base::one, + col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc, no_opt_1_input, + opt_2_inputs), + num_passed, num_skipped); + // Test non-default alpha + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, set_fp_value()(2.f, 1.5f), fp_zero, ldb, ldc, no_opt_1_input, + opt_2_inputs), + num_passed, num_skipped); + // Test non-default beta + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, set_fp_value()(3.2f, 1.f), ldb, ldc, no_opt_1_input, + opt_2_inputs), + num_passed, num_skipped); + // Test 0 alpha + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_zero, fp_one, ldb, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test 0 alpha and beta + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_zero, fp_zero, ldb, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test non-default ldb + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb + 5, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test non-default ldc + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc + 6, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test row major layout + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, + oneapi::mkl::layout::row_major, transpose_A, transpose_B, fp_one, fp_zero, ncols_C, + ncols_C, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test int64 indices + long long_nrows_A = 27, long_ncols_A = 13, long_ncols_C = 6; + long long_ldb = transpose_A == oneapi::mkl::transpose::nontrans ? long_ncols_A : long_nrows_A; + long long_ldc = transpose_A == oneapi::mkl::transpose::nontrans ? long_nrows_A : long_ncols_A; + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, long_nrows_A, long_ncols_A, long_ncols_C, density_A_matrix, + index_zero, col_major, transpose_A, transpose_B, fp_one, + fp_zero, long_ldb, long_ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Use optimize_gemm with only the sparse gemm input + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, true, false), + num_passed, num_skipped); + // Use the 2 optimize_gemm versions + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, true, true), + num_passed, num_skipped); + // Do not use optimize_gemm + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, false, false), + num_passed, num_skipped); +} + +/** + * Helper function to test combination of transpose vals. + * Only test \p conjtrans if \p fpType is complex. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +void test_helper_transpose(sycl::device *dev, int &num_passed, int &num_skipped) { + std::vector transpose_vals{ oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::trans }; + if (complex_info::is_complex) { + transpose_vals.push_back(oneapi::mkl::transpose::conjtrans); + } + for (auto transpose_A : transpose_vals) { + for (auto transpose_B : transpose_vals) { + test_helper(dev, transpose_A, transpose_B, num_passed, num_skipped); + } + } +} + +TEST_P(SparseGemmBufferTests, RealSinglePrecision) { + using fpType = float; + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemmBufferTests, RealDoublePrecision) { + using fpType = double; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemmBufferTests, ComplexSinglePrecision) { + using fpType = std::complex; + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemmBufferTests, ComplexDoublePrecision) { + using fpType = std::complex; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P(SparseGemmBufferTestSuite, SparseGemmBufferTests, + testing::ValuesIn(devices), ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/sparse_blas/source/sparse_gemm_usm.cpp b/tests/unit_tests/sparse_blas/source/sparse_gemm_usm.cpp new file mode 100644 index 000000000..3850f3b99 --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/sparse_gemm_usm.cpp @@ -0,0 +1,330 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl.hpp" +#include "oneapi/mkl/detail/config.hpp" +#include "sparse_reference.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +extern std::vector devices; + +namespace { + +template +int test(sycl::device *dev, intType nrows_A, intType ncols_A, intType ncols_C, + double density_A_matrix, oneapi::mkl::index_base index, + oneapi::mkl::layout dense_matrix_layout, oneapi::mkl::transpose transpose_A, + oneapi::mkl::transpose transpose_B, fpType alpha, fpType beta, intType ldb, intType ldc, + bool opt_1_input, bool opt_2_inputs) { + sycl::queue main_queue(*dev, exception_handler_t()); + + intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; + std::size_t opa_nrows = static_cast( + transpose_A == oneapi::mkl::transpose::nontrans ? nrows_A : ncols_A); + std::size_t opa_ncols = static_cast( + transpose_A == oneapi::mkl::transpose::nontrans ? ncols_A : nrows_A); + + // Input matrix in CSR format + std::vector ia_host, ja_host; + std::vector a_host; + intType nnz = generate_random_matrix(nrows_A, ncols_A, density_A_matrix, + int_index, ia_host, ja_host, a_host); + + // Input and output dense vectors + std::vector b_host, c_host; + rand_matrix(b_host, dense_matrix_layout, opa_ncols, static_cast(ncols_C), + static_cast(ldb)); + rand_matrix(c_host, dense_matrix_layout, opa_nrows, static_cast(ncols_C), + static_cast(ldc)); + std::vector c_ref_host(c_host); + + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), static_cast(nrows_A)); + + auto ia_usm_uptr = malloc_device_uptr(main_queue, ia_host.size()); + auto ja_usm_uptr = malloc_device_uptr(main_queue, ja_host.size()); + auto a_usm_uptr = malloc_device_uptr(main_queue, a_host.size()); + auto b_usm_uptr = malloc_device_uptr(main_queue, b_host.size()); + auto c_usm_uptr = malloc_device_uptr(main_queue, c_host.size()); + + intType *ia_usm = ia_usm_uptr.get(); + intType *ja_usm = ja_usm_uptr.get(); + fpType *a_usm = a_usm_uptr.get(); + fpType *b_usm = b_usm_uptr.get(); + fpType *c_usm = c_usm_uptr.get(); + + std::vector mat_dependencies; + std::vector gemm_dependencies; + // Copy host to device + mat_dependencies.push_back( + main_queue.memcpy(ia_usm, ia_host.data(), ia_host.size() * sizeof(intType))); + mat_dependencies.push_back( + main_queue.memcpy(ja_usm, ja_host.data(), ja_host.size() * sizeof(intType))); + mat_dependencies.push_back( + main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType))); + gemm_dependencies.push_back( + main_queue.memcpy(b_usm, b_host.data(), b_host.size() * sizeof(fpType))); + gemm_dependencies.push_back( + main_queue.memcpy(c_usm, c_host.data(), c_host.size() * sizeof(fpType))); + + sycl::event ev_copy, ev_release; + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + try { + sycl::event event; + CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); + + CALL_RT_OR_CT(event = oneapi::mkl::sparse::set_csr_data, main_queue, handle, nrows_A, + ncols_A, nnz, index, ia_usm, ja_usm, a_usm, mat_dependencies); + + if (opt_1_input) { + CALL_RT_OR_CT(event = oneapi::mkl::sparse::optimize_gemm, main_queue, transpose_A, + handle, { event }); + } + + if (opt_2_inputs) { + CALL_RT_OR_CT(event = oneapi::mkl::sparse::optimize_gemm, main_queue, transpose_A, + transpose_B, dense_matrix_layout, static_cast(ncols_C), + handle, { event }); + } + + gemm_dependencies.push_back(event); + CALL_RT_OR_CT(event = oneapi::mkl::sparse::gemm, main_queue, dense_matrix_layout, + transpose_A, transpose_B, alpha, handle, b_usm, ncols_C, ldb, beta, c_usm, + ldc, gemm_dependencies); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle, + { event }); + + ev_copy = main_queue.memcpy(c_host.data(), c_usm, c_host.size() * sizeof(fpType), event); + } + catch (const sycl::exception &e) { + std::cout << "Caught synchronous SYCL exception during sparse GEMV:\n" + << e.what() << std::endl; + print_error_code(e); + return 0; + } + catch (const oneapi::mkl::unimplemented &e) { + wait_and_free(main_queue, &handle); + return test_skipped; + } + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of sparse GEMV:\n" << error.what() << std::endl; + return 0; + } + + // Compute reference. + prepare_reference_gemm_data(ia_host.data(), ja_host.data(), a_host.data(), nrows_A, ncols_A, + ncols_C, nnz, int_index, dense_matrix_layout, transpose_A, + transpose_B, alpha, beta, ldb, ldc, b_host.data(), + c_ref_host.data()); + + // Compare the results of reference implementation and DPC++ implementation. + ev_copy.wait_and_throw(); + bool valid = check_equal_vector(c_host, c_ref_host); + + ev_release.wait_and_throw(); + return static_cast(valid); +} + +class SparseGemmUsmTests : public ::testing::TestWithParam {}; + +/** + * Helper function to run tests in different configuration. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param transpose_A Transpose value for the A matrix + * @param transpose_B Transpose value for the B matrix + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +void test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_A, + oneapi::mkl::transpose transpose_B, int &num_passed, int &num_skipped) { + double density_A_matrix = 0.8; + fpType fp_zero = set_fp_value()(0.f, 0.f); + fpType fp_one = set_fp_value()(1.f, 0.f); + oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; + oneapi::mkl::layout col_major = oneapi::mkl::layout::col_major; + int nrows_A = 4, ncols_A = 6, ncols_C = 5; + int ldb = transpose_A == oneapi::mkl::transpose::nontrans ? ncols_A : nrows_A; + int ldc = transpose_A == oneapi::mkl::transpose::nontrans ? nrows_A : ncols_A; + bool no_opt_1_input = false; + bool opt_2_inputs = true; + + // Basic test + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test index_base 1 + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, oneapi::mkl::index_base::one, + col_major, transpose_A, transpose_B, fp_one, fp_zero, ldb, ldc, no_opt_1_input, + opt_2_inputs), + num_passed, num_skipped); + // Test non-default alpha + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, set_fp_value()(2.f, 1.5f), fp_zero, ldb, ldc, no_opt_1_input, + opt_2_inputs), + num_passed, num_skipped); + // Test non-default beta + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, set_fp_value()(3.2f, 1.f), ldb, ldc, no_opt_1_input, + opt_2_inputs), + num_passed, num_skipped); + // Test 0 alpha + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_zero, fp_one, ldb, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test 0 alpha and beta + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_zero, fp_zero, ldb, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test non-default ldb + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb + 5, ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test non-default ldc + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc + 6, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test row major layout + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, + oneapi::mkl::layout::row_major, transpose_A, transpose_B, fp_one, fp_zero, ncols_C, + ncols_C, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Test int64 indices + long long_nrows_A = 27, long_ncols_A = 13, long_ncols_C = 6; + long long_ldb = transpose_A == oneapi::mkl::transpose::nontrans ? long_ncols_A : long_nrows_A; + long long_ldc = transpose_A == oneapi::mkl::transpose::nontrans ? long_nrows_A : long_ncols_A; + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, long_nrows_A, long_ncols_A, long_ncols_C, density_A_matrix, + index_zero, col_major, transpose_A, transpose_B, fp_one, + fp_zero, long_ldb, long_ldc, no_opt_1_input, opt_2_inputs), + num_passed, num_skipped); + // Use optimize_gemm with only the sparse gemm input + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, true, false), + num_passed, num_skipped); + // Use the 2 optimize_gemm versions + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, true, true), + num_passed, num_skipped); + // Do not use optimize_gemm + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, nrows_A, ncols_A, ncols_C, density_A_matrix, index_zero, col_major, transpose_A, + transpose_B, fp_one, fp_zero, ldb, ldc, false, false), + num_passed, num_skipped); +} + +/** + * Helper function to test combination of transpose vals. + * Only test \p conjtrans if \p fpType is complex. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +auto test_helper_transpose(sycl::device *dev, int &num_passed, int &num_skipped) { + std::vector transpose_vals{ oneapi::mkl::transpose::nontrans, + oneapi::mkl::transpose::trans }; + if (complex_info::is_complex) { + transpose_vals.push_back(oneapi::mkl::transpose::conjtrans); + } + for (auto transpose_A : transpose_vals) { + for (auto transpose_B : transpose_vals) { + test_helper(dev, transpose_A, transpose_B, num_passed, num_skipped); + } + } +} + +TEST_P(SparseGemmUsmTests, RealSinglePrecision) { + using fpType = float; + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemmUsmTests, RealDoublePrecision) { + using fpType = double; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemmUsmTests, ComplexSinglePrecision) { + using fpType = std::complex; + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemmUsmTests, ComplexDoublePrecision) { + using fpType = std::complex; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper_transpose(GetParam(), num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P(SparseGemmUsmTestSuite, SparseGemmUsmTests, testing::ValuesIn(devices), + ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/sparse_blas/source/sparse_gemv_buffer.cpp b/tests/unit_tests/sparse_blas/source/sparse_gemv_buffer.cpp new file mode 100644 index 000000000..b95636831 --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/sparse_gemv_buffer.cpp @@ -0,0 +1,230 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl.hpp" +#include "oneapi/mkl/detail/config.hpp" +#include "sparse_reference.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +extern std::vector devices; + +namespace { + +template +int test(sycl::device *dev, intType nrows, intType ncols, double density_A_matrix, + oneapi::mkl::index_base index, oneapi::mkl::transpose transpose_val, fpType alpha, + fpType beta, bool use_optimize) { + sycl::queue main_queue(*dev, exception_handler_t()); + + intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; + std::size_t opa_nrows = + static_cast(transpose_val == oneapi::mkl::transpose::nontrans ? nrows : ncols); + std::size_t opa_ncols = + static_cast(transpose_val == oneapi::mkl::transpose::nontrans ? ncols : nrows); + + // Input matrix in CSR format + std::vector ia_host, ja_host; + std::vector a_host; + intType nnz = generate_random_matrix(nrows, ncols, density_A_matrix, int_index, + ia_host, ja_host, a_host); + + // Input and output dense vectors + // The input `x` and the input-output `y` are both initialized to random values on host and device. + std::vector x_host, y_host; + rand_vector(x_host, opa_ncols); + rand_vector(y_host, opa_nrows); + std::vector y_ref_host(y_host); + + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), static_cast(nrows)); + + auto ia_buf = make_buffer(ia_host); + auto ja_buf = make_buffer(ja_host); + auto a_buf = make_buffer(a_host); + auto x_buf = make_buffer(x_host); + auto y_buf = make_buffer(y_host); + + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + sycl::event ev_release; + try { + CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); + + CALL_RT_OR_CT(oneapi::mkl::sparse::set_csr_data, main_queue, handle, nrows, ncols, nnz, + index, ia_buf, ja_buf, a_buf); + + if (use_optimize) { + CALL_RT_OR_CT(oneapi::mkl::sparse::optimize_gemv, main_queue, transpose_val, handle); + } + + CALL_RT_OR_CT(oneapi::mkl::sparse::gemv, main_queue, transpose_val, alpha, handle, x_buf, + beta, y_buf); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle); + } + catch (const sycl::exception &e) { + std::cout << "Caught synchronous SYCL exception during sparse GEMV:\n" + << e.what() << std::endl; + print_error_code(e); + return 0; + } + catch (const oneapi::mkl::unimplemented &e) { + wait_and_free(main_queue, &handle); + return test_skipped; + } + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of sparse GEMV:\n" << error.what() << std::endl; + return 0; + } + + // Compute reference. + prepare_reference_gemv_data(ia_host.data(), ja_host.data(), a_host.data(), nrows, ncols, nnz, + int_index, transpose_val, alpha, beta, x_host.data(), + y_ref_host.data()); + + // Compare the results of reference implementation and DPC++ implementation. + auto y_acc = y_buf.get_host_access(sycl::read_only); + bool valid = check_equal_vector(y_acc, y_ref_host); + + ev_release.wait_and_throw(); + return static_cast(valid); +} + +class SparseGemvBufferTests : public ::testing::TestWithParam {}; + +/** + * Helper function to run tests in different configuration. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param transpose_val Transpose value for the input matrix + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +void test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_val, int &num_passed, + int &num_skipped) { + double density_A_matrix = 0.8; + fpType fp_zero = set_fp_value()(0.f, 0.f); + fpType fp_one = set_fp_value()(1.f, 0.f); + oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; + bool use_optimize = true; + + // Basic test + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, use_optimize), + num_passed, num_skipped); + // Test index_base 1 + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, oneapi::mkl::index_base::one, + transpose_val, fp_one, fp_zero, use_optimize), + num_passed, num_skipped); + // Test non-default alpha + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, + set_fp_value()(2.f, 1.5f), fp_zero, use_optimize), + num_passed, num_skipped); + // Test non-default beta + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_one, + set_fp_value()(3.2f, 1.f), use_optimize), + num_passed, num_skipped); + // Test 0 alpha + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_zero, fp_one, use_optimize), + num_passed, num_skipped); + // Test 0 alpha and beta + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_zero, + fp_zero, use_optimize), + num_passed, num_skipped); + // Test int64 indices + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 27L, 13L, density_A_matrix, index_zero, transpose_val, + fp_one, fp_one, use_optimize), + num_passed, num_skipped); + // Test without optimize_gemv + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, false), + num_passed, num_skipped); +} + +TEST_P(SparseGemvBufferTests, RealSinglePrecision) { + using fpType = float; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemvBufferTests, RealDoublePrecision) { + using fpType = double; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemvBufferTests, ComplexSinglePrecision) { + using fpType = std::complex; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemvBufferTests, ComplexDoublePrecision) { + using fpType = std::complex; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P(SparseGemvBufferTestSuite, SparseGemvBufferTests, + testing::ValuesIn(devices), ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/sparse_blas/source/sparse_gemv_usm.cpp b/tests/unit_tests/sparse_blas/source/sparse_gemv_usm.cpp new file mode 100644 index 000000000..582e0c6f4 --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/sparse_gemv_usm.cpp @@ -0,0 +1,256 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl.hpp" +#include "oneapi/mkl/detail/config.hpp" +#include "sparse_reference.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +extern std::vector devices; + +namespace { + +template +int test(sycl::device *dev, intType nrows, intType ncols, double density_A_matrix, + oneapi::mkl::index_base index, oneapi::mkl::transpose transpose_val, fpType alpha, + fpType beta, bool use_optimize) { + sycl::queue main_queue(*dev, exception_handler_t()); + + intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; + std::size_t opa_nrows = + static_cast(transpose_val == oneapi::mkl::transpose::nontrans ? nrows : ncols); + std::size_t opa_ncols = + static_cast(transpose_val == oneapi::mkl::transpose::nontrans ? ncols : nrows); + + // Input matrix in CSR format + std::vector ia_host, ja_host; + std::vector a_host; + intType nnz = generate_random_matrix(nrows, ncols, density_A_matrix, int_index, + ia_host, ja_host, a_host); + + // Input and output dense vectors + // The input `x` and the input-output `y` are both initialized to random values on host and device. + std::vector x_host, y_host; + rand_vector(x_host, opa_ncols); + rand_vector(y_host, opa_nrows); + std::vector y_ref_host(y_host); + + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), static_cast(nrows)); + + auto ia_usm_uptr = malloc_device_uptr(main_queue, ia_host.size()); + auto ja_usm_uptr = malloc_device_uptr(main_queue, ja_host.size()); + auto a_usm_uptr = malloc_device_uptr(main_queue, a_host.size()); + auto x_usm_uptr = malloc_device_uptr(main_queue, x_host.size()); + auto y_usm_uptr = malloc_device_uptr(main_queue, y_host.size()); + + intType *ia_usm = ia_usm_uptr.get(); + intType *ja_usm = ja_usm_uptr.get(); + fpType *a_usm = a_usm_uptr.get(); + fpType *x_usm = x_usm_uptr.get(); + fpType *y_usm = y_usm_uptr.get(); + + std::vector mat_dependencies; + std::vector gemv_dependencies; + // Copy host to device + mat_dependencies.push_back( + main_queue.memcpy(ia_usm, ia_host.data(), ia_host.size() * sizeof(intType))); + mat_dependencies.push_back( + main_queue.memcpy(ja_usm, ja_host.data(), ja_host.size() * sizeof(intType))); + mat_dependencies.push_back( + main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType))); + gemv_dependencies.push_back( + main_queue.memcpy(x_usm, x_host.data(), x_host.size() * sizeof(fpType))); + gemv_dependencies.push_back( + main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType))); + + sycl::event ev_copy, ev_release; + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + try { + sycl::event event; + CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); + + CALL_RT_OR_CT(event = oneapi::mkl::sparse::set_csr_data, main_queue, handle, nrows, ncols, + nnz, index, ia_usm, ja_usm, a_usm, mat_dependencies); + + if (use_optimize) { + CALL_RT_OR_CT(event = oneapi::mkl::sparse::optimize_gemv, main_queue, transpose_val, + handle, { event }); + } + + gemv_dependencies.push_back(event); + CALL_RT_OR_CT(event = oneapi::mkl::sparse::gemv, main_queue, transpose_val, alpha, handle, + x_usm, beta, y_usm, gemv_dependencies); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle, + { event }); + + ev_copy = main_queue.memcpy(y_host.data(), y_usm, y_host.size() * sizeof(fpType), event); + } + catch (const sycl::exception &e) { + std::cout << "Caught synchronous SYCL exception during sparse GEMV:\n" + << e.what() << std::endl; + print_error_code(e); + return 0; + } + catch (const oneapi::mkl::unimplemented &e) { + wait_and_free(main_queue, &handle); + return test_skipped; + } + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of sparse GEMV:\n" << error.what() << std::endl; + return 0; + } + + // Compute reference. + prepare_reference_gemv_data(ia_host.data(), ja_host.data(), a_host.data(), nrows, ncols, nnz, + int_index, transpose_val, alpha, beta, x_host.data(), + y_ref_host.data()); + + // Compare the results of reference implementation and DPC++ implementation. + ev_copy.wait_and_throw(); + bool valid = check_equal_vector(y_host, y_ref_host); + + ev_release.wait_and_throw(); + return static_cast(valid); +} + +class SparseGemvUsmTests : public ::testing::TestWithParam {}; + +/** + * Helper function to run tests in different configuration. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param transpose_val Transpose value for the input matrix + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +void test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_val, int &num_passed, + int &num_skipped) { + double density_A_matrix = 0.8; + fpType fp_zero = set_fp_value()(0.f, 0.f); + fpType fp_one = set_fp_value()(1.f, 0.f); + oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; + bool use_optimize = true; + + // Basic test + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, use_optimize), + num_passed, num_skipped); + // Test index_base 1 + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, oneapi::mkl::index_base::one, + transpose_val, fp_one, fp_zero, use_optimize), + num_passed, num_skipped); + // Test non-default alpha + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, + set_fp_value()(2.f, 1.5f), fp_zero, use_optimize), + num_passed, num_skipped); + // Test non-default beta + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_one, + set_fp_value()(3.2f, 1.f), use_optimize), + num_passed, num_skipped); + // Test 0 alpha + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_zero, fp_one, use_optimize), + num_passed, num_skipped); + // Test 0 alpha and beta + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_zero, + fp_zero, use_optimize), + num_passed, num_skipped); + // Test int64 indices + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 27L, 13L, density_A_matrix, index_zero, transpose_val, + fp_one, fp_one, use_optimize), + num_passed, num_skipped); + // Test without optimize_gemv + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, 4, 6, density_A_matrix, index_zero, transpose_val, fp_one, fp_zero, false), + num_passed, num_skipped); +} + +TEST_P(SparseGemvUsmTests, RealSinglePrecision) { + using fpType = float; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemvUsmTests, RealDoublePrecision) { + using fpType = double; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemvUsmTests, ComplexSinglePrecision) { + using fpType = std::complex; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseGemvUsmTests, ComplexDoublePrecision) { + using fpType = std::complex; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P(SparseGemvUsmTestSuite, SparseGemvUsmTests, testing::ValuesIn(devices), + ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp b/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp new file mode 100644 index 000000000..4e82ae1f0 --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/sparse_trsv_buffer.cpp @@ -0,0 +1,240 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl.hpp" +#include "oneapi/mkl/detail/config.hpp" +#include "sparse_reference.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +extern std::vector devices; + +namespace { + +template +int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::index_base index, + oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, + oneapi::mkl::diag diag_val, bool use_optimize) { + sycl::queue main_queue(*dev, exception_handler_t()); + + intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; + const std::size_t mu = static_cast(m); + + // Input matrix in CSR format + std::vector ia_host, ja_host; + std::vector a_host; + // Always require values to be present in the diagonal of the sparse matrix. + // The values set in the matrix don't need to be 1s even if diag_val is unit. + const bool require_diagonal = true; + intType nnz = generate_random_matrix( + m, m, density_A_matrix, int_index, ia_host, ja_host, a_host, require_diagonal); + + // Input dense vector. + // The input `x` is initialized to random values on host and device. + std::vector x_host; + rand_vector(x_host, mu); + + // Output and reference dense vectors. + // They are both initialized with a dummy value to catch more errors. + std::vector y_host(mu, -2.0f); + std::vector y_ref_host(y_host); + + // Intel oneMKL does not support unsorted data if + // `sparse::optimize_trsv()` is not called first. + if (use_optimize) { + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), mu); + } + + auto ia_buf = make_buffer(ia_host); + auto ja_buf = make_buffer(ja_host); + auto a_buf = make_buffer(a_host); + auto x_buf = make_buffer(x_host); + auto y_buf = make_buffer(y_host); + + sycl::event ev_release; + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + try { + CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); + + CALL_RT_OR_CT(oneapi::mkl::sparse::set_csr_data, main_queue, handle, m, m, nnz, index, + ia_buf, ja_buf, a_buf); + + if (use_optimize) { + CALL_RT_OR_CT(oneapi::mkl::sparse::optimize_trsv, main_queue, uplo_val, transpose_val, + diag_val, handle); + } + + CALL_RT_OR_CT(oneapi::mkl::sparse::trsv, main_queue, uplo_val, transpose_val, diag_val, + handle, x_buf, y_buf); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle); + } + catch (const sycl::exception &e) { + std::cout << "Caught synchronous SYCL exception during sparse TRSV:\n" + << e.what() << std::endl; + print_error_code(e); + return 0; + } + catch (const oneapi::mkl::unimplemented &e) { + wait_and_free(main_queue, &handle); + return test_skipped; + } + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of sparse TRSV:\n" << error.what() << std::endl; + return 0; + } + + // Compute reference. + prepare_reference_trsv_data(ia_host.data(), ja_host.data(), a_host.data(), m, int_index, + uplo_val, transpose_val, diag_val, x_host.data(), + y_ref_host.data()); + + // Compare the results of reference implementation and DPC++ implementation. + auto y_acc = y_buf.get_host_access(sycl::read_only); + bool valid = check_equal_vector(y_acc, y_ref_host); + + ev_release.wait_and_throw(); + return static_cast(valid); +} + +class SparseTrsvBufferTests : public ::testing::TestWithParam {}; + +/** + * Helper function to run tests in different configuration. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param transpose_val Transpose value for the input matrix + * @param num_passed Increase the number of configurations passed + * @param num_skipped Increase the number of configurations skipped + */ +template +auto test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_val, int &num_passed, + int &num_skipped) { + double density_A_matrix = 0.144; + oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; + oneapi::mkl::uplo lower = oneapi::mkl::uplo::lower; + oneapi::mkl::diag nonunit = oneapi::mkl::diag::nonunit; + int m = 277; + bool use_optimize = true; + + // Basic test + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, lower, + transpose_val, nonunit, use_optimize), + num_passed, num_skipped); + // Test index_base 1 + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, oneapi::mkl::index_base::one, + lower, transpose_val, nonunit, use_optimize), + num_passed, num_skipped); + // Test upper triangular matrix + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, oneapi::mkl::uplo::upper, transpose_val, + nonunit, use_optimize), + num_passed, num_skipped); + // Test unit diagonal matrix + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, lower, + transpose_val, oneapi::mkl::diag::unit, use_optimize), + num_passed, num_skipped); + // Temporarily disable trsv using long indices on GPU + if (!dev->is_gpu()) { + // Test int64 indices + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 15L, density_A_matrix, index_zero, lower, + transpose_val, nonunit, use_optimize), + num_passed, num_skipped); + } + // Test lower without optimize_trsv + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, lower, transpose_val, nonunit, false), + num_passed, num_skipped); + // Test upper without optimize_trsv + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, oneapi::mkl::uplo::upper, transpose_val, + nonunit, false), + num_passed, num_skipped); +} + +TEST_P(SparseTrsvBufferTests, RealSinglePrecision) { + using fpType = float; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseTrsvBufferTests, RealDoublePrecision) { + using fpType = double; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseTrsvBufferTests, ComplexSinglePrecision) { + using fpType = std::complex; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseTrsvBufferTests, ComplexDoublePrecision) { + using fpType = std::complex; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P(SparseTrsvBufferTestSuite, SparseTrsvBufferTests, + testing::ValuesIn(devices), ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/tests/unit_tests/sparse_blas/source/sparse_trsv_usm.cpp b/tests/unit_tests/sparse_blas/source/sparse_trsv_usm.cpp new file mode 100644 index 000000000..8292395fb --- /dev/null +++ b/tests/unit_tests/sparse_blas/source/sparse_trsv_usm.cpp @@ -0,0 +1,261 @@ +/******************************************************************************* +* Copyright 2023 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions +* and limitations under the License. +* +* +* SPDX-License-Identifier: Apache-2.0 +*******************************************************************************/ + +#include +#include +#include + +#if __has_include() +#include +#else +#include +#endif + +#include "oneapi/mkl.hpp" +#include "oneapi/mkl/detail/config.hpp" +#include "sparse_reference.hpp" +#include "test_common.hpp" +#include "test_helper.hpp" + +#include + +extern std::vector devices; + +namespace { + +template +int test(sycl::device *dev, intType m, double density_A_matrix, oneapi::mkl::index_base index, + oneapi::mkl::uplo uplo_val, oneapi::mkl::transpose transpose_val, + oneapi::mkl::diag diag_val, bool use_optimize) { + sycl::queue main_queue(*dev, exception_handler_t()); + + intType int_index = (index == oneapi::mkl::index_base::zero) ? 0 : 1; + const std::size_t mu = static_cast(m); + + // Input matrix in CSR format + std::vector ia_host, ja_host; + std::vector a_host; + const bool require_diagonal = diag_val == oneapi::mkl::diag::nonunit; + intType nnz = generate_random_matrix( + m, m, density_A_matrix, int_index, ia_host, ja_host, a_host, require_diagonal); + + // Input dense vector. + // The input `x` is initialized to random values on host and device. + std::vector x_host; + rand_vector(x_host, mu); + + // Output and reference dense vectors. + // They are both initialized with a dummy value to catch more errors. + std::vector y_host(mu, -2.0f); + std::vector y_ref_host(y_host); + + // Intel oneMKL does not support unsorted data if + // `sparse::optimize_trsv()` is not called first. + if (use_optimize) { + // Shuffle ordering of column indices/values to test sortedness + shuffle_data(ia_host.data(), ja_host.data(), a_host.data(), mu); + } + + auto ia_usm_uptr = malloc_device_uptr(main_queue, ia_host.size()); + auto ja_usm_uptr = malloc_device_uptr(main_queue, ja_host.size()); + auto a_usm_uptr = malloc_device_uptr(main_queue, a_host.size()); + auto x_usm_uptr = malloc_device_uptr(main_queue, x_host.size()); + auto y_usm_uptr = malloc_device_uptr(main_queue, y_host.size()); + + intType *ia_usm = ia_usm_uptr.get(); + intType *ja_usm = ja_usm_uptr.get(); + fpType *a_usm = a_usm_uptr.get(); + fpType *x_usm = x_usm_uptr.get(); + fpType *y_usm = y_usm_uptr.get(); + + std::vector mat_dependencies; + std::vector trsv_dependencies; + // Copy host to device + mat_dependencies.push_back( + main_queue.memcpy(ia_usm, ia_host.data(), ia_host.size() * sizeof(intType))); + mat_dependencies.push_back( + main_queue.memcpy(ja_usm, ja_host.data(), ja_host.size() * sizeof(intType))); + mat_dependencies.push_back( + main_queue.memcpy(a_usm, a_host.data(), a_host.size() * sizeof(fpType))); + trsv_dependencies.push_back( + main_queue.memcpy(x_usm, x_host.data(), x_host.size() * sizeof(fpType))); + trsv_dependencies.push_back( + main_queue.memcpy(y_usm, y_host.data(), y_host.size() * sizeof(fpType))); + + sycl::event ev_copy, ev_release; + oneapi::mkl::sparse::matrix_handle_t handle = nullptr; + try { + sycl::event event; + CALL_RT_OR_CT(oneapi::mkl::sparse::init_matrix_handle, main_queue, &handle); + + CALL_RT_OR_CT(event = oneapi::mkl::sparse::set_csr_data, main_queue, handle, m, m, nnz, + index, ia_usm, ja_usm, a_usm, mat_dependencies); + + if (use_optimize) { + CALL_RT_OR_CT(event = oneapi::mkl::sparse::optimize_trsv, main_queue, uplo_val, + transpose_val, diag_val, handle, { event }); + } + + trsv_dependencies.push_back(event); + CALL_RT_OR_CT(event = oneapi::mkl::sparse::trsv, main_queue, uplo_val, transpose_val, + diag_val, handle, x_usm, y_usm, trsv_dependencies); + + CALL_RT_OR_CT(ev_release = oneapi::mkl::sparse::release_matrix_handle, main_queue, &handle, + { event }); + + ev_copy = main_queue.memcpy(y_host.data(), y_usm, y_host.size() * sizeof(fpType), event); + } + catch (const sycl::exception &e) { + std::cout << "Caught synchronous SYCL exception during sparse TRSV:\n" + << e.what() << std::endl; + print_error_code(e); + return 0; + } + catch (const oneapi::mkl::unimplemented &e) { + wait_and_free(main_queue, &handle); + return test_skipped; + } + catch (const std::runtime_error &error) { + std::cout << "Error raised during execution of sparse TRSV:\n" << error.what() << std::endl; + return 0; + } + + // Compute reference. + prepare_reference_trsv_data(ia_host.data(), ja_host.data(), a_host.data(), m, int_index, + uplo_val, transpose_val, diag_val, x_host.data(), + y_ref_host.data()); + + // Compare the results of reference implementation and DPC++ implementation. + ev_copy.wait_and_throw(); + bool valid = check_equal_vector(y_host, y_ref_host); + + ev_release.wait_and_throw(); + return static_cast(valid); +} + +class SparseTrsvUsmTests : public ::testing::TestWithParam {}; + +/** + * Helper function to run tests in different configuration. + * + * @tparam fpType Complex or scalar, single or double precision type + * @param dev Device to test + * @param transpose_val Transpose value for the input matrix + */ +template +void test_helper(sycl::device *dev, oneapi::mkl::transpose transpose_val, int &num_passed, + int &num_skipped) { + double density_A_matrix = 0.144; + oneapi::mkl::index_base index_zero = oneapi::mkl::index_base::zero; + oneapi::mkl::uplo lower = oneapi::mkl::uplo::lower; + oneapi::mkl::diag nonunit = oneapi::mkl::diag::nonunit; + int m = 277; + bool use_optimize = true; + + // Basic test + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, lower, + transpose_val, nonunit, use_optimize), + num_passed, num_skipped); + // Test index_base 1 + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, oneapi::mkl::index_base::one, + lower, transpose_val, nonunit, use_optimize), + num_passed, num_skipped); + // Test upper triangular matrix + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, oneapi::mkl::uplo::upper, transpose_val, + nonunit, use_optimize), + num_passed, num_skipped); + // Test unit diagonal matrix + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, m, density_A_matrix, index_zero, lower, + transpose_val, oneapi::mkl::diag::unit, use_optimize), + num_passed, num_skipped); + // Temporarily disable trsv using long indices on GPU + if (!dev->is_gpu()) { + // Test int64 indices + EXPECT_TRUE_OR_FUTURE_SKIP(test(dev, 15L, density_A_matrix, index_zero, lower, + transpose_val, nonunit, use_optimize), + num_passed, num_skipped); + } + // Test lower without optimize_trsv + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, lower, transpose_val, nonunit, false), + num_passed, num_skipped); + // Test upper without optimize_trsv + EXPECT_TRUE_OR_FUTURE_SKIP( + test(dev, m, density_A_matrix, index_zero, oneapi::mkl::uplo::upper, transpose_val, + nonunit, false), + num_passed, num_skipped); +} + +TEST_P(SparseTrsvUsmTests, RealSinglePrecision) { + using fpType = float; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseTrsvUsmTests, RealDoublePrecision) { + using fpType = double; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseTrsvUsmTests, ComplexSinglePrecision) { + using fpType = std::complex; + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +TEST_P(SparseTrsvUsmTests, ComplexDoublePrecision) { + using fpType = std::complex; + CHECK_DOUBLE_ON_DEVICE(GetParam()); + int num_passed = 0, num_skipped = 0; + test_helper(GetParam(), oneapi::mkl::transpose::nontrans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::trans, num_passed, num_skipped); + test_helper(GetParam(), oneapi::mkl::transpose::conjtrans, num_passed, num_skipped); + if (num_skipped > 0) { + // Mark that some tests were skipped + GTEST_SKIP() << "Passed: " << num_passed << ", Skipped: " << num_skipped + << " configurations." << std::endl; + } +} + +INSTANTIATE_TEST_SUITE_P(SparseTrsvUsmTestSuite, SparseTrsvUsmTests, testing::ValuesIn(devices), + ::DeviceNamePrint()); + +} // anonymous namespace diff --git a/third-party-programs/THIRD-PARTY-PROGRAMS b/third-party-programs/THIRD-PARTY-PROGRAMS index eb71666fd..fd462fa83 100644 --- a/third-party-programs/THIRD-PARTY-PROGRAMS +++ b/third-party-programs/THIRD-PARTY-PROGRAMS @@ -1,14 +1,82 @@ -oneMKL Third Party Programs File -This file contains the list of third party software ("third party programs") contained in the Intel software and their required notices and/or license terms. This third party software, even if included with the distribution of the Intel software, may be governed by separate license terms, including without limitation, third party license terms, other Intel software license terms, and open source software license terms. These separate license terms govern your use of the third party programs as set forth in the "third-party-programs.txt" or other similarly-named text file. +Intel® oneAPI Math Kernel Library (oneMKL) interfaces -Third party programs and their corresponding required notices and/or license terms are listed below. +This file contains the list of third party software (“third party programs”) +contained in the Intel software and their required notices and/or license terms. +This third party software, even if included with the distribution of the Intel +software, may be governed by separate license terms, including without limitation, +third party license terms, other Intel software license terms, and open source +software license terms. These separate license terms govern your use of the third +party programs as set forth in the “third-party-programs-binary.txt” or other similarly-named text file. -1. Googletest -Copyright 2008, Google Inc. -All rights reserved. +Third party programs and their corresponding required notices and/or license terms are listed below. -BSD-Like License +-------------------------------------------------------------- +1. rocRAND backend files + Copyright (C) 2022 Heidelberg University, Engineering Mathematics and Computing Lab (EMCL) + + cuRAND backend files + cuRAND back-end Copyright (c) 2021, The Regents of the University of + California, through Lawrence Berkeley National Laboratory (subject to receipt + of any required approvals from the U.S. Dept. of Energy). All rights + reserved. + + +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* (1) Redistributions of source code must retain the above copyright notice, +* this list of conditions and the following disclaimer. +* +* (2) Redistributions in binary form must reproduce the above copyright +* notice, this list of conditions and the following disclaimer in the +* documentation and/or other materials provided with the distribution. +* +* (3) Neither the name of the University of California, Lawrence Berkeley +* National Laboratory, U.S. Dept. of Energy nor the names of its contributors +* may be used to endorse or promote products derived from this software +* without specific prior written permission. +* +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +* ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +* LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +* SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +* INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +* CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +* ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +* POSSIBILITY OF SUCH DAMAGE. +* +* You are under no obligation whatsoever to provide any bug fixes, patches, +* or upgrades to the features, functionality or performance of the source +* code ("Enhancements") to anyone; however, if you choose to make your +* Enhancements available either publicly, or directly to Lawrence Berkeley +* National Laboratory, without imposing a separate written license agreement +* for such Enhancements, then you hereby grant the following license: a +* non-exclusive, royalty-free perpetual license to install, use, modify, +* prepare derivative works, incorporate into other computer software, +* distribute, and sublicense such enhancements or derivative works thereof, +* in binary and source code form. +* +* If you have questions about your rights to use or distribute this software, +* please contact Berkeley Lab's Intellectual Property Office at +* IPO@lbl.gov. +* +* NOTICE. This Software was developed under funding from the U.S. Department +* of Energy and the U.S. Government consequently retains certain rights. As +* such, the U.S. Government has been granted for itself and others acting on +* its behalf a paid-up, nonexclusive, irrevocable, worldwide license in the +* Software to reproduce, distribute copies to the public, prepare derivative +* works, and perform publicly and display publicly, and to permit others to do +* so. + +-------------------------------------------------------------- +2. Google C++ Testing Framework + Copyright 2008, Google Inc. + All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are @@ -36,808 +104,12 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +-------------------------------------------------------------- +3. Math.NET Numerics + Copyright (c) 2002-2022 Math.NET -2. CMake – Cross Platform Makefile Generator -Copyright 2000-2020 Kitware, Inc. and Contributors -All rights reserved. - -BSD-Like License - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions -are met: - -* Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -* Neither the name of Kitware, Inc. nor the names of Contributors - may be used to endorse or promote products derived from this - software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ------------------------------------------------------------------------------- - -The following individuals and institutions are among the Contributors: - -* Aaron C. Meadows -* Adriaan de Groot -* Aleksey Avdeev -* Alexander Neundorf -* Alexander Smorkalov -* Alexey Sokolov -* Alex Merry -* Alex Turbov -* Andreas Pakulat -* Andreas Schneider -* André Rigland Brodtkorb -* Axel Huebl, Helmholtz-Zentrum Dresden - Rossendorf -* Benjamin Eikel -* Bjoern Ricks -* Brad Hards -* Christopher Harvey -* Christoph Grüninger -* Clement Creusot -* Daniel Blezek -* Daniel Pfeifer -* Enrico Scholz -* Eran Ifrah -* Esben Mose Hansen, Ange Optimization ApS -* Geoffrey Viola -* Google Inc -* Gregor Jasny -* Helio Chissini de Castro -* Ilya Lavrenov -* Insight Software Consortium -* Jan Woetzel -* Julien Schueller -* Kelly Thompson -* Laurent Montel -* Konstantin Podsvirov -* Mario Bensi -* Martin Gräßlin -* Mathieu Malaterre -* Matthaeus G. Chajdas -* Matthias Kretz -* Matthias Maennich -* Michael Hirsch, Ph.D. -* Michael Stürmer -* Miguel A. Figueroa-Villanueva -* Mike Jackson -* Mike McQuaid -* Nicolas Bock -* Nicolas Despres -* Nikita Krupen'ko -* NVIDIA Corporation -* OpenGamma Ltd. -* Patrick Stotko -* Per Øyvind Karlsen -* Peter Collingbourne -* Petr Gotthard -* Philip Lowman -* Philippe Proulx -* Raffi Enficiaud, Max Planck Society -* Raumfeld -* Roger Leigh -* Rolf Eike Beer -* Roman Donchenko -* Roman Kharitonov -* Ruslan Baratov -* Sebastian Holtermann -* Stephen Kelly -* Sylvain Joubert -* The Qt Company Ltd. -* Thomas Sondergaard -* Tobias Hunger -* Todd Gamblin -* Tristan Carel -* University of Dundee -* Vadim Zhukov -* Will Dicharry - -See version control history for details of individual contributions. - -The above copyright and license notice applies to distributions of -CMake in source and binary form. Third-party software packages supplied -with CMake under compatible licenses provide their own copyright notices -documented in corresponding subdirectories or source files. - ------------------------------------------------------------------------------- - -CMake was initially developed by Kitware with the following sponsorship: - - * National Library of Medicine at the National Institutes of Health - as part of the Insight Segmentation and Registration Toolkit (ITK). - - * US National Labs (Los Alamos, Livermore, Sandia) ASC Parallel - Visualization Initiative. - - * National Alliance for Medical Image Computing (NAMIC) is funded by the - National Institutes of Health through the NIH Roadmap for Medical Research, - Grant U54 EB005149. - - * Kitware, Inc. - - - -3. Ninja Build -Copyright 2011 Google Inc. All Rights Reserved. - -Apache License 2.0 - - Apache License - Version 2.0, January 2010 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - - -4. Intel(R) oneAPI Math Kernel Library - -Intel Simplified Software License (Version February 2020) - -Use and Redistribution. You may use and redistribute the software (the “Software”), without modification, provided the following conditions are met: - -Redistributions must reproduce the above copyright notice and the following terms of use in the Software and in the documentation and/or other materials provided with the distribution. -Neither the name of Intel nor the names of its suppliers may be used to endorse or promote products derived from this Software without specific prior written permission. -No reverse engineering, decompilation, or disassembly of this Software is permitted. -Limited patent license. Intel grants you a world-wide, royalty-free, non-exclusive license under patents it now or hereafter owns or controls to make, have made, use, import, offer to sell and sell (“Utilize”) this Software, but solely to the extent that any such patent is necessary to Utilize the Software alone. The patent license shall not apply to any combinations which include this software. No hardware per se is licensed hereunder. - -Third party programs. The Software may contain Third Party Programs. “Third Party Programs” are third party software, open source software or other Intel software listed in the “third-party-programs.txt” or other similarly named text file that is included with the Software. Third Party Programs, even if included with the distribution of the Software, may be governed by separate license terms, including without limitation, third party license terms, open source software notices and terms, and/or other Intel software license terms. These separate license terms may govern your use of the Third Party Programs. - -DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND ATTORNEYS’ FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. - -LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR UNAUTHORIZED USE OF THE SOFTWARE. - -No support. Intel may make changes to the Software, at any time without notice, and is not obligated to support, update or provide training for the Software. - -Termination. Intel may terminate your right to use the Software in the event of your breach of this Agreement and you fail to cure the breach within a reasonable period of time. - -Feedback. Should you provide Intel with comments, modifications, corrections, enhancements or other input (“Feedback”) related to the Software Intel will be free to use, disclose, reproduce, license or otherwise distribute or exploit the Feedback in its sole discretion without any obligations or restrictions of any kind, including without limitation, intellectual property rights or licensing obligations. - -Compliance with laws. You agree to comply with all relevant laws and regulations governing your use, transfer, import or export (or prohibition thereof) of the Software. - -Governing law. All disputes will be governed by the laws of the United States of America and the State of Delaware without reference to conflict of law principles and subject to the exclusive jurisdiction of the state or federal courts sitting in the State of Delaware, and each party agrees that it submits to the personal jurisdiction and venue of those courts and waives any objections. The United Nations Convention on Contracts for the International Sale of Goods (1980) is specifically excluded and will not apply to the Software. - -*Other names and brands may be claimed as the property of others. - -5. GNU* Fortran Compiler - -GNU General Public License - -GNU General Public License -Version 3, 29 June 2007 -Copyright © 2007 Free Software Foundation, Inc. http://fsf.org/ - -Everyone is permitted to copy and distribute verbatim copies of this -license document, but changing it is not allowed. -Preamble -The GNU General Public License is a free, copyleft license for software and other kinds of works. -The licenses for most software and other practical works are designed to take away your freedom to share and change the works. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change all versions of a program–to make sure it remains free software for all its users. We, the Free Software Foundation, use the GNU General Public License for most of our software; it applies also to any other work released this way by its authors. You can apply it to your programs, too. -When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for them if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs, and that you know you can do these things. -To protect your rights, we need to prevent others from denying you these rights or asking you to surrender the rights. Therefore, you have certain responsibilities if you distribute copies of the software, or if you modify it: responsibilities to respect the freedom of others. -For example, if you distribute copies of such a program, whether gratis or for a fee, you must pass on to the recipients the same freedoms that you received. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. -Developers that use the GNU GPL protect your rights with two steps: (1) assert copyright on the software, and (2) offer you this License giving you legal permission to copy, distribute and/or modify it. -For the developers’ and authors’ protection, the GPL clearly explains that there is no warranty for this free software. For both users’ and authors’ sake, the GPL requires that modified versions be marked as changed, so that their problems will not be attributed erroneously to authors of previous versions. -Some devices are designed to deny users access to install or run modified versions of the software inside them, although the manufacturer can do so. This is fundamentally incompatible with the aim of protecting users’ freedom to change the software. The systematic pattern of such abuse occurs in the area of products for individuals to use, which is precisely where it is most unacceptable. Therefore, we have designed this version of the GPL to prohibit the practice for those products. If such problems arise substantially in other domains, we stand ready to extend this provision to those domains in future versions of the GPL, as needed to protect the freedom of users. -Finally, every program is threatened constantly by software patents. States should not allow patents to restrict development and use of software on general-purpose computers, but in those that do, we wish to avoid the special danger that patents applied to a free program could make it effectively proprietary. To prevent this, the GPL assures that patents cannot be used to render the program non-free. -The precise terms and conditions for copying, distribution and modification follow. -TERMS AND CONDITIONS -1. Definitions. -“This License” refers to version 3 of the GNU General Public License. -“Copyright” also means copyright-like laws that apply to other kinds of works, such as semiconductor masks. -“The Program” refers to any copyrightable work licensed under this License. Each licensee is addressed as “you”. “Licensees” and “recipients” may be individuals or organizations. -To “modify” a work means to copy from or adapt all or part of the work in a fashion requiring copyright permission, other than the making of an exact copy. The resulting work is called a “modified version” of the earlier work or a work “based on” the earlier work. -A “covered work” means either the unmodified Program or a work based on the Program. -To “propagate” a work means to do anything with it that, without permission, would make you directly or secondarily liable for infringement under applicable copyright law, except executing it on a computer or modifying a private copy. Propagation includes copying, distribution (with or without modification), making available to the public, and in some countries other activities as well. -To “convey” a work means any kind of propagation that enables other parties to make or receive copies. Mere interaction with a user through a computer network, with no transfer of a copy, is not conveying. -An interactive user interface displays “Appropriate Legal Notices” to the extent that it includes a convenient and prominently visible feature that (1) displays an appropriate copyright notice, and (2) tells the user that there is no warranty for the work (except to the extent that warranties are provided), that licensees may convey the work under this License, and how to view a copy of this License. If the interface presents a list of user commands or options, such as a menu, a prominent item in the list meets this criterion. -2. Source Code. -The “source code” for a work means the preferred form of the work for making modifications to it. “Object code” means any non-source form of a work. -A “Standard Interface” means an interface that either is an official standard defined by a recognized standards body, or, in the case of interfaces specified for a particular programming language, one that is widely used among developers working in that language. -The “System Libraries” of an executable work include anything, other than the work as a whole, that (a) is included in the normal form of packaging a Major Component, but which is not part of that Major Component, and (b) serves only to enable use of the work with that Major Component, or to implement a Standard Interface for which an implementation is available to the public in source code form. A “Major Component”, in this context, means a major essential component (kernel, window system, and so on) of the specific operating system (if any) on which the executable work runs, or a compiler used to produce the work, or an object code interpreter used to run it. -The “Corresponding Source” for a work in object code form means all the source code needed to generate, install, and (for an executable work) run the object code and to modify the work, including scripts to control those activities. However, it does not include the work’s System Libraries, or general-purpose tools or generally available free programs which are used unmodified in performing those activities but which are not part of the work. For example, Corresponding Source includes interface definition files associated with source files for the work, and the source code for shared libraries and dynamically linked subprograms that the work is specifically designed to require, such as by intimate data communication or control flow between those subprograms and other parts of the work. -The Corresponding Source need not include anything that users can regenerate automatically from other parts of the Corresponding Source. -The Corresponding Source for a work in source code form is that same work. -3. Basic Permissions. -All rights granted under this License are granted for the term of copyright on the Program, and are irrevocable provided the stated conditions are met. This License explicitly affirms your unlimited permission to run the unmodified Program. The output from running a covered work is covered by this License only if the output, given its content, constitutes a covered work. This License acknowledges your rights of fair use or other equivalent, as provided by copyright law. -You may make, run and propagate covered works that you do not convey, without conditions so long as your license otherwise remains in force. You may convey covered works to others for the sole purpose of having them make modifications exclusively for you, or provide you with facilities for running those works, provided that you comply with the terms of this License in conveying all material for which you do not control copyright. Those thus making or running the covered works for you must do so exclusively on your behalf, under your direction and control, on terms that prohibit them from making any copies of your copyrighted material outside their relationship with you. -Conveying under any other circumstances is permitted solely under the conditions stated below. Sublicensing is not allowed; section 10 makes it unnecessary. -4. Protecting Users’ Legal Rights From Anti-Circumvention Law. -No covered work shall be deemed part of an effective technological measure under any applicable law fulfilling obligations under article 11 of the WIPO copyright treaty adopted on 20 December 1996, or similar laws prohibiting or restricting circumvention of such measures. -When you convey a covered work, you waive any legal power to forbid circumvention of technological measures to the extent such circumvention is effected by exercising rights under this License with respect to the covered work, and you disclaim any intention to limit operation or modification of the work as a means of enforcing, against the work’s users, your or third parties’ legal rights to forbid circumvention of technological measures. -5. Conveying Verbatim Copies. -You may convey verbatim copies of the Program’s source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice; keep intact all notices stating that this License and any non-permissive terms added in accord with section 7 apply to the code; keep intact all notices of the absence of any warranty; and give all recipients a copy of this License along with the Program. -You may charge any price or no price for each copy that you convey, and you may offer support or warranty protection for a fee. -6. Conveying Modified Source Versions. -You may convey a work based on the Program, or the modifications to produce it from the Program, in the form of source code under the terms of section 4, provided that you also meet all of these conditions: -1. The work must carry prominent notices stating that you modified it, and giving a relevant date. -2. The work must carry prominent notices stating that it is released under this License and any conditions added under section 7. This requirement modifies the requirement in section 4 to “keep intact all notices”. -3. You must license the entire work, as a whole, under this License to anyone who comes into possession of a copy. This License will therefore apply, along with any applicable section 7 additional terms, to the whole of the work, and all its parts, regardless of how they are packaged. This License gives no permission to license the work in any other way, but it does not invalidate such permission if you have separately received it. -4. If the work has interactive user interfaces, each must display Appropriate Legal Notices; however, if the Program has interactive interfaces that do not display Appropriate Legal Notices, your work need not make them do so. -A compilation of a covered work with other separate and independent works, which are not by their nature extensions of the covered work, and which are not combined with it such as to form a larger program, in or on a volume of a storage or distribution medium, is called an “aggregate” if the compilation and its resulting copyright are not used to limit the access or legal rights of the compilation’s users beyond what the individual works permit. Inclusion of a covered work in an aggregate does not cause this License to apply to the other parts of the aggregate. -7. Conveying Non-Source Forms. -You may convey a covered work in object code form under the terms of sections 4 and 5, provided that you also convey the machine-readable Corresponding Source under the terms of this License, in one of these ways: -1. Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by the Corresponding Source fixed on a durable physical medium customarily used for software interchange. -2. Convey the object code in, or embodied in, a physical product (including a physical distribution medium), accompanied by a written offer, valid for at least three years and valid for as long as you offer spare parts or customer support for that product model, to give anyone who possesses the object code either (1) a copy of the Corresponding Source for all the software in the product that is covered by this License, on a durable physical medium customarily used for software interchange, for a price no more than your reasonable cost of physically performing this conveying of source, or (2) access to copy the Corresponding Source from a network server at no charge. -3. Convey individual copies of the object code with a copy of the written offer to provide the Corresponding Source. This alternative is allowed only occasionally and noncommercially, and only if you received the object code with such an offer, in accord with subsection 6b. -4. Convey the object code by offering access from a designated place (gratis or for a charge), and offer equivalent access to the Corresponding Source in the same way through the same place at no further charge. You need not require recipients to copy the Corresponding Source along with the object code. If the place to copy the object code is a network server, the Corresponding Source may be on a different server (operated by you or a third party) that supports equivalent copying facilities, provided you maintain clear directions next to the object code saying where to find the Corresponding Source. Regardless of what server hosts the Corresponding Source, you remain obligated to ensure that it is available for as long as needed to satisfy these requirements. -5. Convey the object code using peer-to-peer transmission, provided you inform other peers where the object code and Corresponding Source of the work are being offered to the general public at no charge under subsection 6d. -A separable portion of the object code, whose source code is excluded from the Corresponding Source as a System Library, need not be included in conveying the object code work. -A “User Product” is either (1) a “consumer product”, which means any tangible personal property which is normally used for personal, family, or household purposes, or (2) anything designed or sold for incorporation into a dwelling. In determining whether a product is a consumer product, doubtful cases shall be resolved in favor of coverage. For a particular product received by a particular user, “normally used” refers to a typical or common use of that class of product, regardless of the status of the particular user or of the way in which the particular user actually uses, or expects or is expected to use, the product. A product is a consumer product regardless of whether the product has substantial commercial, industrial or non-consumer uses, unless such uses represent the only significant mode of use of the product. -“Installation Information” for a User Product means any methods, procedures, authorization keys, or other information required to install and execute modified versions of a covered work in that User Product from a modified version of its Corresponding Source. The information must suffice to ensure that the continued functioning of the modified object code is in no case prevented or interfered with solely because modification has been made. -If you convey an object code work under this section in, or with, or specifically for use in, a User Product, and the conveying occurs as part of a transaction in which the right of possession and use of the User Product is transferred to the recipient in perpetuity or for a fixed term (regardless of how the transaction is characterized), the Corresponding Source conveyed under this section must be accompanied by the Installation Information. But this requirement does not apply if neither you nor any third party retains the ability to install modified object code on the User Product (for example, the work has been installed in ROM). -The requirement to provide Installation Information does not include a requirement to continue to provide support service, warranty, or updates for a work that has been modified or installed by the recipient, or for the User Product in which it has been modified or installed. Access to a network may be denied when the modification itself materially and adversely affects the operation of the network or violates the rules and protocols for communication across the network. -Corresponding Source conveyed, and Installation Information provided, in accord with this section must be in a format that is publicly documented (and with an implementation available to the public in source code form), and must require no special password or key for unpacking, reading or copying. -8. Additional Terms. -“Additional permissions” are terms that supplement the terms of this License by making exceptions from one or more of its conditions. Additional permissions that are applicable to the entire Program shall be treated as though they were included in this License, to the extent that they are valid under applicable law. If additional permissions apply only to part of the Program, that part may be used separately under those permissions, but the entire Program remains governed by this License without regard to the additional permissions. -When you convey a copy of a covered work, you may at your option remove any additional permissions from that copy, or from any part of it. (Additional permissions may be written to require their own removal in certain cases when you modify the work.) You may place additional permissions on material, added by you to a covered work, for which you have or can give appropriate copyright permission. -Notwithstanding any other provision of this License, for material you add to a covered work, you may (if authorized by the copyright holders of that material) supplement the terms of this License with terms: -1. Disclaiming warranty or limiting liability differently from the terms of sections 15 and 16 of this License; or -2. Requiring preservation of specified reasonable legal notices or author attributions in that material or in the Appropriate Legal Notices displayed by works containing it; or -3. Prohibiting misrepresentation of the origin of that material, or requiring that modified versions of such material be marked in reasonable ways as different from the original version; or -4. Limiting the use for publicity purposes of names of licensors or authors of the material; or -5. Declining to grant rights under trademark law for use of some trade names, trademarks, or service marks; or -6. Requiring indemnification of licensors and authors of that material by anyone who conveys the material (or modified versions of it) with contractual assumptions of liability to the recipient, for any liability that these contractual assumptions directly impose on those licensors and authors. -All other non-permissive additional terms are considered “further restrictions” within the meaning of section 10. If the Program as you received it, or any part of it, contains a notice stating that it is governed by this License along with a term that is a further restriction, you may remove that term. If a license document contains a further restriction but permits relicensing or conveying under this License, you may add to a covered work material governed by the terms of that license document, provided that the further restriction does not survive such relicensing or conveying. -If you add terms to a covered work in accord with this section, you must place, in the relevant source files, a statement of the additional terms that apply to those files, or a notice indicating where to find the applicable terms. -Additional terms, permissive or non-permissive, may be stated in the form of a separately written license, or stated as exceptions; the above requirements apply either way. -9. Termination. -You may not propagate or modify a covered work except as expressly provided under this License. Any attempt otherwise to propagate or modify it is void, and will automatically terminate your rights under this License (including any patent licenses granted under the third paragraph of section 11). -However, if you cease all violation of this License, then your license from a particular copyright holder is reinstated (a) provisionally, unless and until the copyright holder explicitly and finally terminates your license, and (b) permanently, if the copyright holder fails to notify you of the violation by some reasonable means prior to 60 days after the cessation. -Moreover, your license from a particular copyright holder is reinstated permanently if the copyright holder notifies you of the violation by some reasonable means, this is the first time you have received notice of violation of this License (for any work) from that copyright holder, and you cure the violation prior to 30 days after your receipt of the notice. -Termination of your rights under this section does not terminate the licenses of parties who have received copies or rights from you under this License. If your rights have been terminated and not permanently reinstated, you do not qualify to receive new licenses for the same material under section 10. -10. Acceptance Not Required for Having Copies. -You are not required to accept this License in order to receive or run a copy of the Program. Ancillary propagation of a covered work occurring solely as a consequence of using peer-to-peer transmission to receive a copy likewise does not require acceptance. However, nothing other than this License grants you permission to propagate or modify any covered work. These actions infringe copyright if you do not accept this License. Therefore, by modifying or propagating a covered work, you indicate your acceptance of this License to do so. -11. Automatic Licensing of Downstream Recipients. -Each time you convey a covered work, the recipient automatically receives a license from the original licensors, to run, modify and propagate that work, subject to this License. You are not responsible for enforcing compliance by third parties with this License. -An “entity transaction” is a transaction transferring control of an organization, or substantially all assets of one, or subdividing an organization, or merging organizations. If propagation of a covered work results from an entity transaction, each party to that transaction who receives a copy of the work also receives whatever licenses to the work the party’s predecessor in interest had or could give under the previous paragraph, plus a right to possession of the Corresponding Source of the work from the predecessor in interest, if the predecessor has it or can get it with reasonable efforts. -You may not impose any further restrictions on the exercise of the rights granted or affirmed under this License. For example, you may not impose a license fee, royalty, or other charge for exercise of rights granted under this License, and you may not initiate litigation (including a cross-claim or counterclaim in a lawsuit) alleging that any patent claim is infringed by making, using, selling, offering for sale, or importing the Program or any portion of it. -12. Patents. -A “contributor” is a copyright holder who authorizes use under this License of the Program or a work on which the Program is based. The work thus licensed is called the contributor’s “contributor version”. -A contributor’s “essential patent claims” are all patent claims owned or controlled by the contributor, whether already acquired or hereafter acquired, that would be infringed by some manner, permitted by this License, of making, using, or selling its contributor version, but do not include claims that would be infringed only as a consequence of further modification of the contributor version. For purposes of this definition, “control” includes the right to grant patent sublicenses in a manner consistent with the requirements of this License. -Each contributor grants you a non-exclusive, worldwide, royalty-free patent license under the contributor’s essential patent claims, to make, use, sell, offer for sale, import and otherwise run, modify and propagate the contents of its contributor version. -In the following three paragraphs, a “patent license” is any express agreement or commitment, however denominated, not to enforce a patent (such as an express permission to practice a patent or covenant not to sue for patent infringement). To “grant” such a patent license to a party means to make such an agreement or commitment not to enforce a patent against the party. -If you convey a covered work, knowingly relying on a patent license, and the Corresponding Source of the work is not available for anyone to copy, free of charge and under the terms of this License, through a publicly available network server or other readily accessible means, then you must either (1) cause the Corresponding Source to be so available, or (2) arrange to deprive yourself of the benefit of the patent license for this particular work, or (3) arrange, in a manner consistent with the requirements of this License, to extend the patent license to downstream recipients. “Knowingly relying” means you have actual knowledge that, but for the patent license, your conveying the covered work in a country, or your recipient’s use of the covered work in a country, would infringe one or more identifiable patents in that country that you have reason to believe are valid. -If, pursuant to or in connection with a single transaction or arrangement, you convey, or propagate by procuring conveyance of, a covered work, and grant a patent license to some of the parties receiving the covered work authorizing them to use, propagate, modify or convey a specific copy of the covered work, then the patent license you grant is automatically extended to all recipients of the covered work and works based on it. -A patent license is “discriminatory” if it does not include within the scope of its coverage, prohibits the exercise of, or is conditioned on the non-exercise of one or more of the rights that are specifically granted under this License. You may not convey a covered work if you are a party to an arrangement with a third party that is in the business of distributing software, under which you make payment to the third party based on the extent of your activity of conveying the work, and under which the third party grants, to any of the parties who would receive the covered work from you, a discriminatory patent license (a) in connection with copies of the covered work conveyed by you (or copies made from those copies), or (b) primarily for and in connection with specific products or compilations that contain the covered work, unless you entered into that arrangement, or that patent license was granted, prior to 28 March 2007. -Nothing in this License shall be construed as excluding or limiting any implied license or other defenses to infringement that may otherwise be available to you under applicable patent law. -13. No Surrender of Others’ Freedom. -If conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot convey a covered work so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not convey it at all. For example, if you agree to terms that obligate you to collect a royalty for further conveying from those to whom you convey the Program, the only way you could satisfy both those terms and this License would be to refrain entirely from conveying the Program. -14. Use with the GNU Affero General Public License. -Notwithstanding any other provision of this License, you have permission to link or combine any covered work with a work licensed under version 3 of the GNU Affero General Public License into a single combined work, and to convey the resulting work. The terms of this License will continue to apply to the part which is the covered work, but the special requirements of the GNU Affero General Public License, section 13, concerning interaction through a network will apply to the combination as such. -15. Revised Versions of this License. -The Free Software Foundation may publish revised and/or new versions of the GNU General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. -Each version is given a distinguishing version number. If the Program specifies that a certain numbered version of the GNU General Public License “or any later version” applies to it, you have the option of following the terms and conditions either of that numbered version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of the GNU General Public License, you may choose any version ever published by the Free Software Foundation. -If the Program specifies that a proxy can decide which future versions of the GNU General Public License can be used, that proxy’s public statement of acceptance of a version permanently authorizes you to choose that version for the Program. -Later license versions may give you additional or different permissions. However, no additional obligations are imposed on any author or copyright holder as a result of your choosing to follow a later version. -16. Disclaimer of Warranty. -THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM “AS IS” WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. -17. Limitation of Liability. -IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. -18. Interpretation of Sections 15 and 16. -If the disclaimer of warranty and limitation of liability provided above cannot be given local legal effect according to their terms, reviewing courts shall apply local law that most closely approximates an absolute waiver of all civil liability in connection with the Program, unless a warranty or assumption of liability accompanies a copy of the Program in return for a fee. -END OF TERMS AND CONDITIONS -How to Apply These Terms to Your New Programs -If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. -To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively state the exclusion of warranty; and each file should have at least the “copyright” line and a pointer to where the full notice is found. -one line to give the program's name and a brief idea of what it does. -Copyright (C) year name of author - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU General Public License as published by -the Free Software Foundation, either version 3 of the License, or (at -your option) any later version. - -This program is distributed in the hope that it will be useful, but -WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU -General Public License for more details. - -You should have received a copy of the GNU General Public License -along with this program. If not, see http://www.gnu.org/licenses/. -Also add information on how to contact you by electronic and paper mail. -If the program does terminal interaction, make it output a short notice like this when it starts in an interactive mode: -program Copyright (C) year name of author -This program comes with ABSOLUTELY NO WARRANTY; for details type ‘show w’. -This is free software, and you are welcome to redistribute it -under certain conditions; type ‘show c’ for details. -The hypothetical commands ‘show w’ and ‘show c’ should show the appropriate parts of the General Public License. Of course, your program’s commands might be different; for a GUI interface, you would use an “about box”. -You should also get your employer (if you work as a programmer) or school, if any, to sign a “copyright disclaimer” for the program, if necessary. For more information on this, and how to apply and follow the GNU GPL, see http://www.gnu.org/licenses/. -The GNU General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Lesser General Public License instead of this License. But first, please read http://www.gnu.org/philosophy/why-not-lgpl.html. - -6. Sphinx -Copyright © 2007-2019 by the Sphinx team (see AUTHORS file). -All rights reserved. - -License for Sphinx -================== - -Copyright (c) 2007-2019 by the Sphinx team (see AUTHORS file). -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - -* Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in the - documentation and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -Licenses for incorporated software -================================== - -The included smartypants module, included as sphinx.util.smartypants, -is available under the following license: - ----------------------------------------------------------------------- -SmartyPants_ license:: - - Copyright (c) 2003 John Gruber - (https://daringfireball.net/projects/smartypants/) - All rights reserved. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials - provided with the distribution. - - * Neither the name "SmartyPants" nor the names of its - contributors may be used to endorse or promote products - derived from this software without specific prior written - permission. - - This software is provided by the copyright holders and - contributors "as is" and any express or implied warranties, - including, but not limited to, the implied warranties of - merchantability and fitness for a particular purpose are - disclaimed. In no event shall the copyright owner or contributors - be liable for any direct, indirect, incidental, special, - exemplary, or consequential damages (including, but not limited - to, procurement of substitute goods or services; loss of use, - data, or profits; or business interruption) however caused and on - any theory of liability, whether in contract, strict liability, or - tort (including negligence or otherwise) arising in any way out of - the use of this software, even if advised of the possibility of - such damage. - - -smartypants.py license:: - - smartypants.py is a derivative work of SmartyPants. - - Redistribution and use in source and binary forms, with or without - modification, are permitted provided that the following conditions - are met: - - * Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - - * Redistributions in binary form must reproduce the above - copyright notice, this list of conditions and the following - disclaimer in the documentation and/or other materials - provided with the distribution. - - This software is provided by the copyright holders and - contributors "as is" and any express or implied warranties, - including, but not limited to, the implied warranties of - merchantability and fitness for a particular purpose are - disclaimed. In no event shall the copyright owner or contributors - be liable for any direct, indirect, incidental, special, - exemplary, or consequential damages (including, but not limited - to, procurement of substitute goods or services; loss of use, - data, or profits; or business interruption) however caused and on - any theory of liability, whether in contract, strict liability, or - tort (including negligence or otherwise) arising in any way out of - the use of this software, even if advised of the possibility of - such damage. ----------------------------------------------------------------------- - -The included JQuery JavaScript library is available under the MIT -license: - ----------------------------------------------------------------------- -Copyright (c) 2008 John Resig, https://jquery.com/ - -Permission is hereby granted, free of charge, to any person obtaining -a copy of this software and associated documentation files (the -"Software"), to deal in the Software without restriction, including -without limitation the rights to use, copy, modify, merge, publish, -distribute, sublicense, and/or sell copies of the Software, and to -permit persons to whom the Software is furnished to do so, subject to -the following conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF -MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE -LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION -OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION -WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ----------------------------------------------------------------------- - -The included Underscore JavaScript library is available under the MIT -license: - ----------------------------------------------------------------------- -Copyright (c) 2009 Jeremy Ashkenas, DocumentCloud - -Permission is hereby granted, free of charge, to any person -obtaining a copy of this software and associated documentation -files (the "Software"), to deal in the Software without -restriction, including without limitation the rights to use, -copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the -Software is furnished to do so, subject to the following -conditions: - -The above copyright notice and this permission notice shall be -included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, -EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES -OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT -HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, -WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING -FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR -OTHER DEALINGS IN THE SOFTWARE. - -------------------------------------------------------------------------------- - -The included implementation of NumpyDocstring._parse_numpydoc_see_also_section -was derived from code under the following license: - -------------------------------------------------------------------------------- - -Copyright (C) 2008 Stefan van der Walt , Pauli Virtanen - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - - 1. Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - 2. Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer in - the documentation and/or other materials provided with the - distribution. - -THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR -IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, -INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) -HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, -STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING -IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - - -7. LAPACK -Copyright (c) 1992-2017 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. -Copyright (c) 2000-2017 The University of California Berkeley. All -rights reserved. -Copyright (c) 2006-2017 The University of Colorado Denver. All rights -reserved. - -BSD-Like License - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - -- Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer listed - in this license in the documentation and/or other materials - provided with the distribution. - -- Neither the name of the copyright holders nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -The copyright holders provide no reassurances that the source code -provided does not infringe any patent, copyright, or any other -intellectual property rights of third parties. The copyright holders -disclaim any liability to any recipient for claims brought against -recipient by any third party for infringement of that parties -intellectual property rights. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -8. cuRAND back-end -Copyright (c) 2021, The Regents of the University of California, through -Lawrence Berkeley National Laboratory (subject to receipt of any required -approvals from the U.S. Dept. of Energy). All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -(1) Redistributions of source code must retain the above copyright notice, -this list of conditions and the following disclaimer. - -(2) Redistributions in binary form must reproduce the above copyright -notice, this list of conditions and the following disclaimer in the -documentation and/or other materials provided with the distribution. - -(3) Neither the name of the University of California, Lawrence Berkeley -National Laboratory, U.S. Dept. of Energy nor the names of its contributors -may be used to endorse or promote products derived from this software -without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE -POSSIBILITY OF SUCH DAMAGE. - -You are under no obligation whatsoever to provide any bug fixes, patches, -or upgrades to the features, functionality or performance of the source -code ("Enhancements") to anyone; however, if you choose to make your -Enhancements available either publicly, or directly to Lawrence Berkeley -National Laboratory, without imposing a separate written license agreement -for such Enhancements, then you hereby grant the following license: a -non-exclusive, royalty-free perpetual license to install, use, modify, -prepare derivative works, incorporate into other computer software, -distribute, and sublicense such enhancements or derivative works thereof, -in binary and source code form. - - -The following third party programs have their own third party program files. These additional third party program files are as follows: - -1. List of third party for Intel(R) oneAPI Math Kernel Library is available in the third-party-programs-onemkl file. - +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: -*Other names and brands may be claimed as the property of others. +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/third-party-programs/THIRD-PARTY-PROGRAMS-ONEMKL b/third-party-programs/THIRD-PARTY-PROGRAMS-ONEMKL deleted file mode 100644 index c1588f95e..000000000 --- a/third-party-programs/THIRD-PARTY-PROGRAMS-ONEMKL +++ /dev/null @@ -1,524 +0,0 @@ -Intel(R) oneAPI Math Kernel Library (oneMKL) Third Party Programs File - -This file is the "third-party-programs.txt" file specified in the associated Intel end user license agreement for the Intel software you are licensing. - -Third party programs and their corresponding required notices and/or license terms are listed below. - -------------------------------------------------------------- - -1. Netlib BLACS - Basic Linear Algebra Communication Subprograms: -Copyright (c) 1992-2013 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. -Copyright (c) 2000-2013 The University of California Berkeley. All rights reserved. -Copyright (c) 2006-2013 The University of Colorado Denver. All rights reserved. - -Netlib LAPACK: -Copyright (c) 1992-2013 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. -Copyright (c) 2000-2013 The University of California Berkeley. All rights reserved. -Copyright (c) 2006-2013 The University of Colorado Denver. All rights reserved. - -Netlib LAPACK95: -Copyright (c) 1992-2013 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. -Copyright (c) 2000-2013 The University of California Berkeley. All rights reserved. -Copyright (c) 2006-2013 The University of Colorado Denver. All rights reserved. - -Netlib ScaLAPACK: -Copyright (c) 1992-2013 The University of Tennessee and The University of Tennessee Research Foundation. All rights reserved. -Copyright (c) 2000-2013 The University of California Berkeley. All rights reserved. -Copyright (c) 2006-2013 The University of Colorado Denver. All rights reserved. - - -Modified BSD License - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -- Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer listed in this license in the documentation and/or other materials provided with the distribution. - -- Neither the name of the copyright holders nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -The copyright holders provide no reassurances that the source code provided does not infringe any patent, copyright, or any other intellectual property rights of third parties. The copyright holders disclaim any liability to any recipient for claims brought against recipient by any third party for infringement of that parties intellectual property rights. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -------------------------------------------------------------- -2. HPCG: High Performance Conjugate Gradient Benchmark: - HPCG - 3.1 - March 28, 2019 - Michael A. Heroux - Scalable Algorithms Group, Center for Computing Research - Sandia National Laboratories, Albuquerque, NM - Piotr Luszczek - Jack Dongarra - University of Tennessee, Knoxville - Innovative Computing Laboratory - (C) Copyright 2013-2019 All Rights Reserved - -Modified BSD License - -Redistribution  and  use in  source and binary forms, with or without modification, are  permitted provided  that the following  conditions are met: - -1. Redistributions  of  source  code  must retain the above copyright notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce  the above copyright notice, this list of conditions,  and the following disclaimer in the documentation and/or other materials provided with the distribution. - -3. The name of the  University,  the name of the  Laboratory,  or the names  of  its  contributors  may  not  be used to endorse or promote products  derived   from   this  software  without  specific  written permission. - --- Disclaimer: - -THIS  SOFTWARE  IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES,  INCLUDING,  BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY -OR  CONTRIBUTORS  BE  LIABLE FOR ANY  DIRECT,  INDIRECT,  INCIDENTAL, -SPECIAL,  EXEMPLARY,  OR  CONSEQUENTIAL DAMAGES  (INCLUDING,  BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA OR PROFITS; OR BUSINESS INTERRUPTION)  HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,  STRICT LIABILITY,  OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ----------------------------------------------------------------- -3. Mersenne Twister with improved initialization: - Copyright (C) 1997 - 2002, Makoto Matsumoto and Takuji Nishimura, All rights reserved. - - Modified BSD License - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright -notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -3. The names of its contributors may not be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ----------------------------------------------------------------- -4. SFMT: - Copyright (c) 2006,2007-2014 Mutsuo Saito, Makoto Matsumoto and Hiroshima University. All rights reserved. - - Modified BSD License - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above -copyright notice, this list of conditions and the following -disclaimer in the documentation and/or other materials provided with the distribution. -* Neither the name of the Hiroshima University nor the names of its contributors may be used to endorse or promote products -derived from this software without specific prior written -permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ----------------------------------------------------------------- -5. Sobol sequence generator: - Copyright (c) 2008, Frances Y. Kuo and Stephen Joe, All rights reserved. - - Modified BSD License - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -    * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -    * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -    * Neither the names of the copyright holders nor the names of the University of New South Wales and the University of Waikato and its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - - --------------------------------------------------------------- -6. The FEAST Eigenvalue Solver: - Copyright (c) 2009-2012, The Regents of the University of Massachusetts, Amherst. Developed by E. Polizzi. All rights reserved. - - Modified BSD License - -Redistribution and use in source and binary forms, with or without modification,are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -3. Neither the name of the University nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF -LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - ----------------------------------------------------------------- -7. xbyak: -Copyright (c) 2007 MITSUNARI Shigeo -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -Redistributions of source code must retain the above copyright notice, this -list of conditions and the following disclaimer. -Redistributions in binary form must reproduce the above copyright notice, -this list of conditions and the following disclaimer in the documentation -and/or other materials provided with the distribution. -Neither the name of the copyright owner nor the names of its contributors may -be used to endorse or promote products derived from this software without -specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE -ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE -LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR -CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF -SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS -INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN -CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) -ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF -THE POSSIBILITY OF SUCH DAMAGE. ------------------------------------------------------------------------------ -ソースコード形式かバイナリ形式か、変更するかしないかを問わず、以下の条件を満た -す場合に限り、再頒布および使用が許可されます。 - -ソースコードを再頒布する場合、上記の著作権表示、本条件一覧、および下記免責条項 -を含めること。 -バイナリ形式で再頒布する場合、頒布物に付属のドキュメント等の資料に、上記の著作 -権表示、本条件一覧、および下記免責条項を含めること。 -書面による特別の許可なしに、本ソフトウェアから派生した製品の宣伝または販売促進 -に、著作権者の名前またはコントリビューターの名前を使用してはならない。 -本ソフトウェアは、著作権者およびコントリビューターによって「現状のまま」提供さ -れており、明示黙示を問わず、商業的な使用可能性、および特定の目的に対する適合性 -に関する暗黙の保証も含め、またそれに限定されない、いかなる保証もありません。 -著作権者もコントリビューターも、事由のいかんを問わず、 損害発生の原因いかんを -問わず、かつ責任の根拠が契約であるか厳格責任であるか(過失その他の)不法行為で -あるかを問わず、仮にそのような損害が発生する可能性を知らされていたとしても、 -本ソフトウェアの使用によって発生した(代替品または代用サービスの調達、使用の -喪失、データの喪失、利益の喪失、業務の中断も含め、またそれに限定されない)直接 -損害、間接損害、偶発的な損害、特別損害、懲罰的損害、または結果損害について、 -一切責任を負わないものとします。 - ----------------------------------------------------------------- -8. Intel(R) Instrumentation and Tracing Technology API: - Copyright(c) 2005-2014, Intel Corporation, All rights reserved. - - Modified BSD License - -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: -* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. -* Redistributions in binary form must reproduce the above copyright notice, his list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -* Neither the name of Intel Corporation nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - - ------------------------------------------------------- -9. Intel Open Source Technology Center Safe String Library: - strcat_s.c - October 2008, Bo Berry - Copyright (c) 2008-2011 by Cisco Systems, Inc. All rights reserved. - - - level-zero: - Copyright (c) 2019 Intel Corporation - - MIT License - -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -------------------------------------------------------------- -10. Khronos Group - OpenCL: Copyright (c) 2008 - 2013 The Khronos Group Inc. - - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. - - "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work(an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to -communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable(except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and -attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution -notices within Derivative Works that You distribute, alongside -or as an addendum to the NOTICE text from the Work, provided -that such additional attribution notices cannot be construed as modifying the License. - - You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each -Contributor provides its Contributions) on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. - - Copyright [yyyy] [name of copyright owner] - - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and limitations under the License. - -------------------------------------------------------------- -11. HPL - High-Performance Linpack Benchmark: - HPL - 2.3 - December 2, 2018 - Antoine P. Petitet - University of Tennessee, Knoxville - Innovative Computing Laboratory - (C) Copyright 2000-2008 All Rights Reserved - - HPL License - - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - - 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - - 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions, and the following disclaimer in the documentation and/or other materials provided with the distribution. - - 3. All advertising materials mentioning features or use of this software must display the following acknowledgement: - This product includes software developed at the University of Tennessee, Knoxville, Innovative Computing Laboratory. - - 4. The name of the University, the name of the Laboratory, or the names of its contributors may not be used to endorse or promote products derived from this software without specific written permission. - - -- Disclaimer: - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE UNIVERSITY OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - - -------------------------------------------------------------- -12. Intel(R) oneAPI Threading Building Blocks: - Copyright 2005-2019 Intel Corporation.  All Rights Reserved. - - Intel(R) Integrated Performance Primitives Library - Copyright 2020 Intel Corporation. All Rights Reserved. - - Intel(R) OpenMP* Runtime: - Copyright 1985-2019 Intel Corporation.  All Rights Reserved. - - The source code contained or described herein and all documents related to the source code ("Material") are owned by Intel Corporation or its suppliers or licensors.  Title to the Material remains with Intel Corporation or its suppliers and licensors.  The Material is protected by worldwide copyright laws and treaty provisions.  No part of the Material may be used, copied, reproduced, modified, published, uploaded, posted, transmitted, distributed, or disclosed in any way without Intel's prior express written permission. - - No license under any patent, copyright, trade secret or other intellectual property right is granted to or conferred upon you by disclosure or delivery of the Materials, either expressly, by implication, inducement, estoppel or otherwise.  Any license under such intellectual property rights must be express and approved by Intel in writing. - - Portions of this software are protected under the following patents: -        U.S. Patent 5,812,852 -        U.S. Patent 6,792,599 -        U.S. Patent 7,069,556 -        U.S. Patent 7,328,433 -        U.S. Patent 7,500,242 - - - - - - Intel Simplified Software License (Version April 2018) - -Copyright (c) 2018 Intel Corporation. - -Use and Redistribution. You may use and redistribute the software (the “Software”), without modification, provided the following conditions are met: - -* Redistributions must reproduce the above copyright notice and the following terms of use in the Software and in the documentation and/or other materials provided with the distribution. -* Neither the name of Intel nor the names of its suppliers may be used to endorse or promote products derived from this Software without specific prior written permission. -* No reverse engineering, decompilation, or disassembly of this Software is permitted. - -Limited patent license. Intel grants you a world-wide, royalty-free, non-exclusive license under patents it now or hereafter owns or controls to make, have made, use, import, offer to sell and sell (“Utilize”) this Software, but solely to the extent that any such patent is necessary to Utilize the Software alone. The patent license shall not apply to any combinations which include this software. No hardware per se is licensed hereunder. - -Third party and other Intel programs. “Third Party Programs” are the files listed in the “third-party-programs.txt” text file that is included with the Software and may include Intel programs under separate license terms. Third Party Programs, even if included with the distribution of the Materials, are governed by separate license terms and those license terms solely govern your use of those programs. - -DISCLAIMER. THIS SOFTWARE IS PROVIDED "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT ARE DISCLAIMED. THIS SOFTWARE IS NOT INTENDED FOR USE IN SYSTEMS OR APPLICATIONS WHERE FAILURE OF THE SOFTWARE MAY CAUSE PERSONAL INJURY OR DEATH AND YOU AGREE THAT YOU ARE FULLY RESPONSIBLE FOR ANY CLAIMS, COSTS, DAMAGES, EXPENSES, AND ATTORNEYS’ FEES ARISING OUT OF ANY SUCH USE, EVEN IF ANY CLAIM ALLEGES THAT INTEL WAS NEGLIGENT REGARDING THE DESIGN OR MANUFACTURE OF THE MATERIALS. - -LIMITATION OF LIABILITY. IN NO EVENT WILL INTEL BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. YOU AGREE TO INDEMNIFY AND HOLD INTEL HARMLESS AGAINST ANY CLAIMS AND EXPENSES RESULTING FROM YOUR USE OR UNAUTHORIZED USE OF THE SOFTWARE. - -No support. Intel may make changes to the Software, at any time without notice, and is not obligated to support, update or provide training for the Software. - -Termination. Intel may terminate your right to use the Software in the event of your breach of this Agreement and you fail to cure the breach within a reasonable period of time. - -Feedback. Should you provide Intel with comments, modifications, corrections, enhancements or other input (“Feedback”) related to the Software Intel will be free to use, disclose, reproduce, license or otherwise distribute or exploit the Feedback in its sole discretion without any obligations or restrictions of any kind, including without limitation, intellectual property rights or licensing obligations. - -Compliance with laws. You agree to comply with all relevant laws and regulations governing your use, transfer, import or export (or prohibition thereof) of the Software. - -Governing law. All disputes will be governed by the laws of the United States of America and the State of Delaware without reference to conflict of law principles and subject to the exclusive jurisdiction of the state or federal courts sitting in the State of Delaware, and each party agrees that it submits to the personal jurisdiction and venue of those courts and waives any objections. The United Nations Convention on Contracts for the International Sale of Goods (1980) is specifically excluded and will not apply to the Software. - -*Other names and brands may be claimed as the property of others. - - ----------------------------------------------------------------- - -13. Microsoft HPC Pack Software Development Kit (SDK) - - Micrososft Software License - - MICROSOFT SOFTWARE LICENSE TERMS -MICROSOFT HPC PACK SOFTWARE DEVELOPMENT KIT (SDK) -Last updated: August 2018 -These license terms are an agreement between Microsoft Corporation (or based on where you live, one of its affiliates) and you. Please read them. They apply to the software named above, which includes the media on which you received it, if any. The terms also apply to any Microsoft -updates, -supplements, -Internet-based services, and -support services -for this software, unless other terms accompany those items. If so, those terms apply. -By using the software, you accept these terms. If you do not accept them, do not use the software. -If you comply with these license terms, you have the perpetual rights below. -INSTALLATION AND USE RIGHTS. You may install and use any number of copies of the software on your devices to design, develop and test your programs. Subject to your compliance with the terms below, you are permitted to distribute the software, in object code form only, in programs you develop and you may permit distributors of your programs to copy and distribute the software as part of those programs. -Distribution Requirements. For any software you distribute in your programs, you must -add significant primary functionality to it in your programs; -for any Distributable Code having a filename extension of .lib, distribute only the results of running such Distributable Code through a linker with your program; -require distributors and external end users to agree to terms that protect it at least as much as this agreement; -display your valid copyright notice on your programs; and -indemnify, defend, and hold harmless Microsoft from any claims, including attorneys’ fees, related to the distribution or use of your programs. -ii. Distribution Restrictions. You may not -alter any copyright, trademark or patent notice in the software; -use Microsoft’s trademarks in your programs’ names or in a way that suggests your programs come from or are endorsed by Microsoft; -distribute software to run on a platform other than the Windows platform; -include software in malicious, deceptive or unlawful programs; or -modify or distribute the software so that any part of it becomes subject to an Excluded License. An Excluded License is one that requires, as a condition of use, modification or distribution, that -the code be disclosed or distributed in source code form; or -others have the right to modify it. -2. SCOPE OF LICENSE. The software is licensed, not sold. This agreement only gives you some rights to use the software. Microsoft reserves all other rights. Unless applicable law gives you more rights despite this limitation, you may use the software only as expressly permitted in this agreement. In doing so, you must comply with any technical limitations in the software that only allow you to use it in certain ways. You may not -work around any technical limitations in the software; -reverse engineer, decompile or disassemble the software, except and only to the extent that applicable law expressly permits, despite this limitation; -publish the software for others to copy; -rent, lease or lend the software; -transfer the software or this agreement to any third party; or -use the software for commercial software hosting services. -3. DOCUMENTATION. Any person that has valid access to your computer or internal network may copy and use the documentation for your internal, reference purposes. -4. EXPORT RESTRICTIONS. You must comply with all domestic and international export laws and regulations that apply to the software, which include restrictions on destinations, end users, and end use. For further information on export restrictions, visit http://aka.ms/exporting. -5. SUPPORT SERVICES. Microsoft is not obligated under this agreement to provide any support services for the software. Any support provided is “as is”, “with all faults”, and without warranty of any kind. -6. UPDATES. The software may periodically check for updates, and download and install them for you. You may obtain updates only from Microsoft or authorized sources. Microsoft may need to update your system to provide you with updates. You agree to receive these automatic updates without any additional notice. Updates may not include or support all existing software features, services, or peripheral devices. -7. BINDING ARBITRATION AND CLASS ACTION WAIVER. This Section applies if you live in (or, if a business, your principal place of business is in) the United States. If you and Microsoft have a dispute, you and Microsoft agree to try for 60 days to resolve it informally. If you and Microsoft can’t, you and Microsoft agree to binding individual arbitration before the American Arbitration Association under the Federal Arbitration Act (“FAA”), and not to sue in court in front of a judge or jury. Instead, a neutral arbitrator will decide. Class action lawsuits, class-wide arbitrations, private attorney-general actions, and any other proceeding where someone acts in a representative capacity are not allowed; nor is combining individual proceedings without the consent of all parties. The complete Arbitration Agreement contains more terms and is at http://aka.ms/arb-agreement-1. You and Microsoft agree to these terms. -8. ENTIRE AGREEMENT. This agreement, and any other terms Microsoft may provide for supplements, updates, or third-party applications, is the entire agreement for the software. -9. APPLICABLE LAW AND PLACE TO RESOLVE DISPUTES. If you acquired the software in the United States or Canada, the laws of the state or province where you live (or, if a business, where your principal place of business is located) govern the interpretation of this agreement, claims for its breach, and all other claims (including consumer protection, unfair competition, and tort claims), regardless of conflict of laws principles, except that the FAA governs everything related to arbitration. If you acquired the software in any other country, its laws apply, except that the FAA governs everything related to arbitration. If U.S. federal jurisdiction exists, you and Microsoft consent to exclusive jurisdiction and venue in the federal court in King County, Washington for all disputes heard in court (excluding arbitration). If not, you and Microsoft consent to exclusive jurisdiction and venue in the Superior Court of King County, Washington for all disputes heard in court (excluding arbitration). -10. CONSUMER RIGHTS; REGIONAL VARIATIONS. This agreement describes certain legal rights. You may have other rights, including consumer rights, under the laws of your state or country. Separate and apart from your relationship with Microsoft, you may also have rights with respect to the party from which you acquired the software. This agreement does not change those other rights if the laws of your state or country do not permit it to do so. For example, if you acquired the software in one of the below regions, or mandatory country law applies, then the following provisions apply to you: -Australia. You have statutory guarantees under the Australian Consumer Law and nothing in this agreement is intended to affect those rights. -Canada. If you acquired this software in Canada, you may stop receiving updates by turning off the automatic update feature, disconnecting your device from the Internet (if and when you re-connect to the Internet, however, the software will resume checking for and installing updates), or uninstalling the software. The product documentation, if any, may also specify how to turn off updates for your specific device or software. -Germany and Austria. -Warranty. The properly licensed software will perform substantially as described in any Microsoft materials that accompany the software. However, Microsoft gives no contractual guarantee in relation to the licensed software. -Limitation of Liability. In case of intentional conduct, gross negligence, claims based on the Product Liability Act, as well as, in case of death or personal or physical injury, Microsoft is liable according to the statutory law. -Subject to the foregoing clause ii., Microsoft will only be liable for slight negligence if Microsoft is in breach of such material contractual obligations, the fulfillment of which facilitate the due performance of this agreement, the breach of which would endanger the purpose of this agreement and the compliance with which a party may constantly trust in (so-called "cardinal obligations"). In other cases of slight negligence, Microsoft will not be liable for slight negligence. -11. DISCLAIMER OF WARRANTY. THE SOFTWARE IS LICENSED “AS IS.” YOU BEAR THE RISK OF USING IT. MICROSOFT GIVES NO EXPRESS WARRANTIES, GUARANTEES, OR CONDITIONS. TO THE EXTENT PERMITTED UNDER APPLICABLE LAWS, MICROSOFT EXCLUDES ALL IMPLIED WARRANTIES, INCLUDING MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. -12. LIMITATION ON AND EXCLUSION OF DAMAGES. IF YOU HAVE ANY BASIS FOR RECOVERING DAMAGES DESPITE THE PRECEDING DISCLAIMER OF WARRANTY, YOU CAN RECOVER FROM MICROSOFT AND ITS SUPPLIERS ONLY DIRECT DAMAGES UP TO U.S. $5.00. YOU CANNOT RECOVER ANY OTHER DAMAGES, INCLUDING CONSEQUENTIAL, LOST PROFITS, SPECIAL, INDIRECT, OR INCIDENTAL DAMAGES. -This limitation applies to (a) anything related to the software, services, content (including code) on third party Internet sites, or third party applications; and (b) claims for breach of contract, warranty, guarantee, or condition; strict liability, negligence, or other tort; or any other claim; in each case to the extent permitted by applicable law. -It also applies even if Microsoft knew or should have known about the possibility of the damages. The above limitation or exclusion may not apply to you because your state, province, or country may not allow the exclusion or limitation of incidental, consequential, or other damages. - - -Please note: As this software is distributed in Canada, some of the clauses in this agreement are provided below in French. -Remarque: Ce logiciel étant distribué au Canada, certaines des clauses dans ce contrat sont fournies ci-dessous en français. -EXONÉRATION DE GARANTIE. Le logiciel visé par une licence est offert « tel quel ». Toute utilisation de ce logiciel est à votre seule risque et péril. Microsoft n’accorde aucune autre garantie expresse. Vous pouvez bénéficier de droits additionnels en vertu du droit local sur la protection des consommateurs, que ce contrat ne peut modifier. La ou elles sont permises par le droit locale, les garanties implicites de qualité marchande, d’adéquation à un usage particulier et d’absence de contrefaçon sont exclues. -LIMITATION DES DOMMAGES-INTÉRÊTS ET EXCLUSION DE RESPONSABILITÉ POUR LES DOMMAGES. Vous pouvez obtenir de Microsoft et de ses fournisseurs une indemnisation en cas de dommages directs uniquement à hauteur de 5,00 $ US. Vous ne pouvez prétendre à aucune indemnisation pour les autres dommages, y compris les dommages spéciaux, indirects ou accessoires et pertes de bénéfices. -Cette limitation concerne: -• tout ce qui est relié au logiciel, aux services ou au contenu (y compris le code) figurant sur des sites Internet tiers ou dans des programmes tiers; et -• les réclamations au titre de violation de contrat ou de garantie, ou au titre de responsabilité stricte, de négligence ou d’une autre faute dans la limite autorisée par la loi en vigueur. -Elle s’applique également, même si Microsoft connaissait ou devrait connaître l’éventualité d’un tel dommage. Si votre pays n’autorise pas l’exclusion ou la limitation de responsabilité pour les dommages indirects, accessoires ou de quelque nature que ce soit, il se peut que la limitation ou l’exclusion ci-dessus ne s’appliquera pas à votre égard. -EFFET JURIDIQUE. Le présent contrat décrit certains droits juridiques. Vous pourriez avoir d’autres droits prévus par les lois de votre pays. Le présent contrat ne modifie pas les droits que vous confèrent les lois de votre pays si celles-ci ne le permettent pas. - -*Other names and brands may be claimed as the property of others. - ----------------------------------------------------------------- -14. Netlib XBLAS - Extra Precise Basic Linear Algebra Subroutines: - Copyright (c) 2008-2009 The University of California Berkeley. All rights reserved. - -$COPYRIGHT$ - -Additional copyrights may follow - -$HEADER$ - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are -met: - -- Redistributions of source code must retain the above copyright - notice, this list of conditions and the following disclaimer. - -- Redistributions in binary form must reproduce the above copyright - notice, this list of conditions and the following disclaimer listed - in this license in the documentation and/or other materials - provided with the distribution. - -- Neither the name of the copyright holders nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - ----------------------------------------------------------------- - - -The following third party programs have their own third party program files. These additional third party program files are as follows: - -1. List of third party for Intel Open Source Technology Center Safe String Library is available in the third-party-programs-safestring.txt. - -2. List of third party for Intel(R) oneAPI Threading Building Blocks for Linux is available in /tbb//licensing/third-party-programs.txt. - -3. List of third party for Intel(R) oneAPI Threading Building Blocks for Windows is available in \tbb\\licensing\third-party-programs.txt. - -4. List of third party for Intel(R) OpenMP* Runtime is available in the third-party-programs-openmp.txt file. - -5. List of third party for Intel(R) Integrated Performance Primitives Library is available in the third-party-programs-ipp.txt file. - -*Other names and brands may be claimed as the property of others. - - - - - -