diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 3d08f5a..2d0c2ac 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -12,52 +12,47 @@ defaults: shell: bash -e -l {0} jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 name: ${{ matrix.sys.compiler }} ${{ matrix.sys.version }} - ${{ matrix.sys.blas }} strategy: fail-fast: false matrix: sys: - - {compiler: gcc, version: '8', blas: OpenBLAS} - - {compiler: gcc, version: '8', blas: mkl} - - {compiler: gcc, version: '9', blas: OpenBLAS} - - {compiler: gcc, version: '9', blas: mkl} - - {compiler: gcc, version: '10', blas: OpenBLAS} - - {compiler: gcc, version: '10', blas: mkl} - {compiler: gcc, version: '11', blas: OpenBLAS} - {compiler: gcc, version: '11', blas: mkl} - - {compiler: clang, version: '15', blas: OpenBLAS} - - {compiler: clang, version: '15', blas: mkl} - - {compiler: clang, version: '16', blas: OpenBLAS} - - {compiler: clang, version: '16', blas: mkl} + - {compiler: gcc, version: '12', blas: OpenBLAS} + - {compiler: gcc, version: '12', blas: mkl} + - {compiler: gcc, version: '13', blas: OpenBLAS} + - {compiler: gcc, version: '13', blas: mkl} + - {compiler: gcc, version: '14', blas: OpenBLAS} + - {compiler: gcc, version: '14', blas: mkl} + - {compiler: clang, version: '17', blas: OpenBLAS} + - {compiler: clang, version: '18', blas: mkl} + - {compiler: clang, version: '19', blas: OpenBLAS} + - {compiler: clang, version: '20', blas: mkl} steps: + - name: Install GCC + if: matrix.sys.compiler == 'gcc' + uses: egor-tensin/setup-gcc@v1 + with: + version: ${{matrix.sys.version}} + platform: x64 - - name: Setup GCC - if: ${{ matrix.sys.compiler == 'gcc' }} + - name: Install LLVM and Clang + if: matrix.sys.compiler == 'clang' run: | - GCC_VERSION=${{ matrix.sys.version }} - sudo apt-get update - sudo apt-get --no-install-suggests --no-install-recommends install g++-$GCC_VERSION - CC=gcc-$GCC_VERSION - echo "CC=$CC" >> $GITHUB_ENV - CXX=g++-$GCC_VERSION - echo "CXX=$CXX" >> $GITHUB_ENV + wget https://apt.llvm.org/llvm.sh + chmod +x llvm.sh + sudo ./llvm.sh ${{matrix.sys.version}} + sudo apt-get install -y clang-tools-${{matrix.sys.version}} + sudo update-alternatives --install /usr/bin/clang++ clang++ /usr/bin/clang++-${{matrix.sys.version}} 200 + sudo update-alternatives --install /usr/bin/clang clang /usr/bin/clang-${{matrix.sys.version}} 200 + sudo update-alternatives --install /usr/bin/clang-scan-deps clang-scan-deps /usr/bin/clang-scan-deps-${{matrix.sys.version}} 200 + sudo update-alternatives --set clang /usr/bin/clang-${{matrix.sys.version}} + sudo update-alternatives --set clang++ /usr/bin/clang++-${{matrix.sys.version}} + sudo update-alternatives --set clang-scan-deps /usr/bin/clang-scan-deps-${{matrix.sys.version}} - - name: Setup clang - if: ${{ matrix.sys.compiler == 'clang' }} - run: | - LLVM_VERSION=${{ matrix.sys.version }} - wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - || exit 1 - sudo add-apt-repository "deb http://apt.llvm.org/focal/ llvm-toolchain-focal-$LLVM_VERSION main" || exit 1 - sudo apt-get update || exit 1 - sudo apt-get --no-install-suggests --no-install-recommends install clang-$LLVM_VERSION || exit 1 - sudo apt-get --no-install-suggests --no-install-recommends install g++-9 g++-9-multilib || exit 1 - sudo ln -s /usr/include/asm-generic /usr/include/asm - CC=clang-$LLVM_VERSION - echo "CC=$CC" >> $GITHUB_ENV - CXX=clang++-$LLVM_VERSION - echo "CXX=$CXX" >> $GITHUB_ENV - name: Checkout code uses: actions/checkout@v3 @@ -76,10 +71,10 @@ jobs: - name: Install OpenBLAS if: ${{ matrix.sys.blas == 'OpenBLAS' }} - run: micromamba install openblas==0.3 blas-devel + run: micromamba install 'openblas==0.3.29=pthreads*' blas-devel - name: Configure using CMake - run: cmake -Bbuild -DDOWNLOAD_GTEST=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_C_COMPILER=$CC -DCMAKE_CXX_COMPILER=$CXX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib + run: cmake -Bbuild -DDOWNLOAD_GTEST=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib - name: Build working-directory: build diff --git a/.github/workflows/osx.yml b/.github/workflows/osx.yml index 7d99475..41ac08c 100644 --- a/.github/workflows/osx.yml +++ b/.github/workflows/osx.yml @@ -13,13 +13,14 @@ defaults: jobs: build: runs-on: macos-${{ matrix.os }} - name: macos-${{ matrix.os }} - mkl + name: macos-${{ matrix.os }} - OpenBLAS strategy: fail-fast: false matrix: os: - - 11 - - 12 + - 13 + - 14 + - 15 steps: @@ -34,8 +35,8 @@ jobs: init-shell: bash cache-downloads: true - - name: Install mkl - run: micromamba install mkl + - name: Install OpenBLAS + run: micromamba install 'openblas==0.3.29=openmp*' blas-devel - name: Configure using CMake run: cmake -Bbuild -DDOWNLOAD_GTEST=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0a401f4..d957878 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.5.0 hooks: - id: check-added-large-files - id: check-case-conflict @@ -16,28 +16,28 @@ repos: - id: detect-private-key - id: check-merge-conflict - repo: https://github.com/Lucas-C/pre-commit-hooks - rev: v1.3.1 + rev: v1.5.4 hooks: - id: forbid-tabs - id: remove-tabs args: [--whitespaces-count, '4'] - repo: https://github.com/macisamuele/language-formatters-pre-commit-hooks - rev: v2.5.0 + rev: v2.11.0 hooks: - id: pretty-format-yaml args: [--autofix, --indent, '2'] - repo: https://github.com/tdegeus/cpp_comment_format - rev: v0.2.0 + rev: v0.2.1 hooks: - id: cpp_comment_format - repo: https://github.com/tdegeus/conda_envfile - rev: v0.4.1 + rev: v0.4.2 hooks: - id: conda_envfile_parse files: environment.yaml # Externally provided executables (so we can use them with editors as well). - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v15.0.6 + rev: v17.0.6 hooks: - id: clang-format files: .*\.[hc]pp$ diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b3b6ac..eaf10e2 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -7,7 +7,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.8) +cmake_minimum_required(VERSION 3.29) project(xtensor-blas) set(INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/include) diff --git a/benchmark/CMakeLists.txt b/benchmark/CMakeLists.txt index 3c7617d..aeb40ea 100644 --- a/benchmark/CMakeLists.txt +++ b/benchmark/CMakeLists.txt @@ -6,7 +6,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.1) +cmake_minimum_required(VERSION 3.29) if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) project(xtensor-benchmark) diff --git a/environment-dev.yml b/environment-dev.yml index 803bb77..ea65f05 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -3,4 +3,4 @@ channels: - conda-forge dependencies: - cmake -- xtensor>=0.25.0 +- xtensor>=0.25.0,<0.26 diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 4817e13..2b555dc 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -7,7 +7,7 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 3.8) +cmake_minimum_required(VERSION 3.29) if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) project(xtensor-blas-test) diff --git a/test/downloadGTest.cmake.in b/test/downloadGTest.cmake.in index 23ae933..e06bb06 100644 --- a/test/downloadGTest.cmake.in +++ b/test/downloadGTest.cmake.in @@ -7,14 +7,14 @@ # The full license is in the file LICENSE, distributed with this software. # ############################################################################ -cmake_minimum_required(VERSION 2.8.2) +cmake_minimum_required(VERSION 3.29) project(googletest-download NONE) include(ExternalProject) ExternalProject_Add(googletest - GIT_REPOSITORY https://github.com/JohanMabille/googletest.git - GIT_TAG warnings + GIT_REPOSITORY https://github.com/google/googletest.git + GIT_TAG v1.16.0 SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" CONFIGURE_COMMAND "" diff --git a/test/test_dot_extended.cpp b/test/test_dot_extended.cpp index 4764566..ea1906f 100644 --- a/test/test_dot_extended.cpp +++ b/test/test_dot_extended.cpp @@ -33,17 +33,20 @@ namespace xt xarray py_a = { {{0.3745401188473625, 0.9507143064099162, 0.7319939418114051, 0.5986584841970366, 0.1560186404424365}, {0.1559945203362026, 0.0580836121681995, 0.8661761457749352, 0.6011150117432088, 0.7080725777960455}, - {0.0205844942958024, 0.9699098521619943, 0.8324426408004217, 0.2123391106782762, 0.1818249672071006}}, + {0.0205844942958024, 0.9699098521619943, 0.8324426408004217, 0.2123391106782762, 0.1818249672071006 + }}, {{0.1834045098534338, 0.3042422429595377, 0.5247564316322378, 0.4319450186421158, 0.2912291401980419}, {0.6118528947223795, 0.1394938606520418, 0.2921446485352182, 0.3663618432936917, 0.4560699842170359}, - {0.7851759613930136, 0.1996737821583597, 0.5142344384136116, 0.5924145688620425, 0.0464504127199977}}}; + {0.7851759613930136, 0.1996737821583597, 0.5142344384136116, 0.5924145688620425, 0.0464504127199977}} + }; // py_b xarray py_b = { {0.6075448519014384, 0.1705241236872915, 0.0650515929852795, 0.9488855372533332, 0.9656320330745594}, {0.8083973481164611, 0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013}, {0.1220382348447788, 0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169}, - {0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527}}; + {0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527} + }; // py_dr xarray py_dr = { {{1.1560019913607258, 1.1421672030085086, 1.1263990512143978, 1.2813094834150083}, @@ -52,7 +55,8 @@ namespace xt {{0.8885299172713558, 0.7159304454839006, 0.6592223836380569, 0.7792380767202456}, {1.2025508600129964, 1.0170636073271262, 0.6049520893427571, 0.8853834024749684}, - {1.15151820221699, 1.1715787914743192, 0.763094187597877, 1.182339688054495}}}; + {1.15151820221699, 1.1715787914743192, 0.763094187597877, 1.182339688054495}} + }; xt::xtensor bas = xt::transpose(py_b); @@ -78,18 +82,16 @@ namespace xt {{0.5426960831582485, 0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173}, {0.7722447692966574, 0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171}, - {0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297}}}; + {0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297}} + }; // py_b - xarray py_b = { - 0.8631034258755935, - 0.6232981268275579, - 0.3308980248526492, - 0.0635583502860236, - 0.3109823217156622}; + xarray py_b = + {0.8631034258755935, 0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622}; // py_dr xarray py_dr = { {1.8736790686065976, 1.0197269167779506, 0.8888679673881792}, - {1.1333287572487494, 1.0638629967411402, 1.1932578950872312}}; + {1.1333287572487494, 1.0638629967411402, 1.1932578950872312} + }; auto xres = xt::linalg::dot(py_a, py_b); std::cout << xres << std::endl; @@ -118,7 +120,8 @@ namespace xt {{100, 101, 102, 103, 104}, {105, 106, 107, 108, 109}, {110, 111, 112, 113, 114}, - {115, 116, 117, 118, 119}}}}; + {115, 116, 117, 118, 119}}} + }; // py_b xarray py_b = { {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}, {12, 13, 14}}, @@ -127,7 +130,8 @@ namespace xt {{30, 31, 32}, {33, 34, 35}, {36, 37, 38}, {39, 40, 41}, {42, 43, 44}}, - {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}, {54, 55, 56}, {57, 58, 59}}}; + {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}, {54, 55, 56}, {57, 58, 59}} + }; // py_dr xarray py_dr = { {{{{90, 100, 110}, {240, 250, 260}, {390, 400, 410}, {540, 550, 560}}, @@ -176,7 +180,8 @@ namespace xt {{3390, 3950, 4510}, {11790, 12350, 12910}, {20190, 20750, 21310}, {28590, 29150, 29710}}, - {{3540, 4125, 4710}, {12315, 12900, 13485}, {21090, 21675, 22260}, {29865, 30450, 31035}}}}}; + {{3540, 4125, 4710}, {12315, 12900, 13485}, {21090, 21675, 22260}, {29865, 30450, 31035}}}} + }; auto xres = xt::linalg::dot(py_a, py_b); EXPECT_TRUE(xt::allclose(xres, py_dr)); diff --git a/test/test_lapack.cpp b/test/test_lapack.cpp index 7121631..2a1ae35 100644 --- a/test/test_lapack.cpp +++ b/test/test_lapack.cpp @@ -29,7 +29,8 @@ namespace xt {0.01114647, 0.93096724, 0.35509599, 0.35329223, 0.65759337}, {0.27868701, 0.376794, 0.63310696, 0.90892131, 0.35454718}, {0.02962539, 0.20561053, 0.2004051, 0.83641883, 0.08335324}, - {0.76958296, 0.23132089, 0.33539779, 0.70616527, 0.40256713}}; + {0.76958296, 0.23132089, 0.33539779, 0.70616527, 0.40256713} + }; auto eig_res = xt::linalg::eig(eig_arg_0); xtensor, 1> eig_expected_0 = { @@ -37,7 +38,8 @@ namespace xt 0.24898158 + 0.51158566i, 0.24898158 - 0.51158566i, 0.66252212 + 0.i, - 0.28854321 + 0.i}; + 0.28854321 + 0.i + }; xtensor, 2> eig_expected_1 = { {-0.67843725 + 0.i, @@ -50,17 +52,15 @@ namespace xt -0.42892828 + 0.30675499i, -0.60497432 + 0.i, -0.55233486 + 0.i}, - {-0.39453548 + 0.i, - 0.10153693 - 0.12657944i, - 0.10153693 + 0.12657944i, - 0.35111489 + 0.i, - 0.80267297 + 0.i}, + {-0.39453548 + 0.i, 0.10153693 - 0.12657944i, 0.10153693 + 0.12657944i, 0.35111489 + 0.i, 0.80267297 + 0.i + }, {-0.15349367 + 0.i, -0.04903747 + 0.08226059i, -0.04903747 - 0.08226059i, 0.48726345 + 0.i, -0.10533951 + 0.i}, - {-0.46162383 + 0.i, 0.65501769 + 0.i, 0.65501769 - 0.i, -0.19620376 + 0.i, 0.16463982 + 0.i}}; + {-0.46162383 + 0.i, 0.65501769 + 0.i, 0.65501769 - 0.i, -0.19620376 + 0.i, 0.16463982 + 0.i} + }; xarray> eigvals = std::get<0>(eig_res); xarray> eigvecs = std::get<1>(eig_res); @@ -76,19 +76,22 @@ namespace xt {0.24, 0.39, 0.42, -0.16}, {0.39, -0.11, 0.79, 0.63}, {0.42, 0.79, -0.25, 0.48}, - {-0.16, 0.63, 0.48, -0.03}}; + {-0.16, 0.63, 0.48, -0.03} + }; xarray eig_arg_1 = { {4.16, -3.12, 0.56, -0.10}, {-3.12, 5.03, -0.83, 1.09}, {0.56, -0.83, 0.76, 0.34}, - {-0.10, 1.09, 0.34, 1.18}}; + {-0.10, 1.09, 0.34, 1.18} + }; auto eig_res = xt::linalg::eigh(eig_arg_0, eig_arg_1); xtensor eig_expected_0 = {-2.225448, -0.454756, 0.100076, 1.127039}; xtensor eig_expected_1 = { {0.031913, 0.327020, 0.682699, 0.425628}, {0.265466, 0.565845, 0.056645, 0.520961}, {0.713483, -0.371290, -0.077102, 0.714215}, - {-0.647650, -0.659561, -0.724409, -0.193227}}; + {-0.647650, -0.659561, -0.724409, -0.193227} + }; xarray eigvals = std::get<0>(eig_res); xarray eigvecs = std::get<1>(eig_res); for (unsigned i = 0; i < 4; ++i) @@ -114,7 +117,8 @@ namespace xt xarray expected = { {0.55555556, -0.11111111, -0.22222222}, {0.22222222, 0.55555556, 0.11111111}, - {-0.33333333, -0.33333333, 0.33333333}}; + {-0.33333333, -0.33333333, 0.33333333} + }; EXPECT_TRUE(allclose(expected, t)); @@ -166,7 +170,8 @@ namespace xt {0.44615865, 0.89495389, 0., 0., 0.}, {0.39541532, 0.24253783, 0.88590187, 0., 0.}, {-0.36681098, -0.26249522, 0.0338034, 0.89185386, 0.}, - {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433}}; + {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433} + }; xarray b = {1, 1, 1, -1, -1}; auto x = linalg::solve_cholesky(A, b); @@ -176,7 +181,8 @@ namespace xt 0.26609253571318064, 1.03715526610177222, -1.3449222878385465, - -1.81183493755905478}; + -1.81183493755905478 + }; for (std::size_t i = 0; i < x_expected.shape()[0]; ++i) { @@ -191,7 +197,8 @@ namespace xt {0.44615865, 0.89495389, 0., 0., 0.}, {0.39541532, 0.24253783, 0.88590187, 0., 0.}, {-0.36681098, -0.26249522, 0.0338034, 0.89185386, 0.}, - {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433}}; + {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433} + }; const xt::xtensor b = {0.38867999, 0.46467046, 0.39042938, -0.2736973, 0.20813322}; auto x = linalg::solve_triangular(A, b); @@ -201,7 +208,8 @@ namespace xt 0.32544416381003327, 0.17813128230545805, -0.05799057434472885, - 0.08606304705465571}; + 0.08606304705465571 + }; for (std::size_t i = 0; i < x_expected.shape()[0]; ++i) { diff --git a/test/test_linalg.cpp b/test/test_linalg.cpp index e73ccd7..5df2feb 100644 --- a/test/test_linalg.cpp +++ b/test/test_linalg.cpp @@ -39,7 +39,8 @@ namespace xt xarray t3expected = { {1.06199622e+45, 1.36986674e+45, 1.67773727e+45}, {3.26000325e+45, 4.20507151e+45, 5.15013977e+45}, - {5.45801029e+45, 7.04027628e+45, 8.62254226e+45}}; + {5.45801029e+45, 7.04027628e+45, 8.62254226e+45} + }; EXPECT_TRUE(allclose(t3res, t3expected)); xarray t4arg_0 = {{-2., 1., 3.}, {3., 2., 1.}, {1., 2., 5.}}; @@ -48,14 +49,16 @@ namespace xt xarray t4expected = { {0.09259259, -0.09259259, 0.01851852}, {0.35185185, 0.64814815, -0.46296296}, - {-0.2037037, -0.2962963, 0.25925926}}; + {-0.2037037, -0.2962963, 0.25925926} + }; EXPECT_TRUE(allclose(t4res, t4expected)); auto t5res = xt::linalg::matrix_power(t4arg_0, -13); xarray t5expected = { {-0.02119919, -0.02993041, 0.02400524}, {0.15202629, 0.21469317, -0.17217602}, - {-0.0726041, -0.10253451, 0.08222825}}; + {-0.0726041, -0.10253451, 0.08222825} + }; EXPECT_TRUE(allclose(t5res, t5expected)); } @@ -74,7 +77,8 @@ namespace xt xarray> arg_0 = { {0.95368636 + 0.32324664i, 0.49936872 + 0.22164004i, 0.30452434 + 0.78922905i}, {0.84118920 + 0.59652768i, 0.42052057 + 0.97211559i, 0.19916742 + 0.83068058i}, - {0.67065616 + 0.56830636i, 0.00268706 + 0.29410473i, 0.69147455 + 0.7052149i}}; + {0.67065616 + 0.56830636i, 0.00268706 + 0.29410473i, 0.69147455 + 0.7052149i} + }; auto res = linalg::det(arg_0); auto res_i = std::imag(res); @@ -88,7 +92,8 @@ namespace xt xarray> arg_0 = { {0.13373658 + 0.43025551i, 0.42593478 + 0.17539337i, 0.18840853 + 0.24669458i}, {0.82800224 + 0.11797823i, 0.40310379 + 0.14037109i, 0.88204561 + 0.96870283i}, - {0.35427657 + 0.1233739i, 0.22740960 + 0.94019582i, 0.05410180 + 0.86462543i}}; + {0.35427657 + 0.1233739i, 0.22740960 + 0.94019582i, 0.05410180 + 0.86462543i} + }; auto resc = linalg::slogdet(arg_0); auto sc = std::get<0>(resc); auto sl = std::real(std::get<1>(resc)); @@ -102,7 +107,8 @@ namespace xt xarray arg_b = { {0.20009016, 0.33997118, 0.74433611}, {0.52721448, 0.2449798, 0.49085606}, - {0.49757477, 0.97304175, 0.05011255}}; + {0.49757477, 0.97304175, 0.05011255} + }; auto res = linalg::slogdet(arg_b); double expected_0 = 1.0; double expected_1 = -1.3017524147193602; @@ -121,12 +127,14 @@ namespace xt xarray expected_0 = { {-0.13511895, 0.90281571, 0.40824829}, {-0.49633514, 0.29493179, -0.81649658}, - {-0.85755134, -0.31295213, 0.40824829}}; + {-0.85755134, -0.31295213, 0.40824829} + }; xarray expected_1 = {1.42267074e+01, 1.26522599e+00, 5.89938022e-16}; xarray expected_2 = { {-0.4663281, -0.57099079, -0.67565348}, {-0.78477477, -0.08545673, 0.61386131}, - {-0.40824829, 0.81649658, -0.40824829}}; + {-0.40824829, 0.81649658, -0.40824829} + }; EXPECT_TRUE(allclose(std::get<0>(res), expected_0)); EXPECT_TRUE(allclose(std::get<1>(res), expected_1)); @@ -172,7 +180,8 @@ namespace xt xarray expected_1 = { {-0.33220683, -0.93041946, -0.15478453}, {-0.66309119, 0.34708777, -0.66320446}, - {-0.67078216, 0.11768479, 0.73225787}}; + {-0.67078216, 0.11768479, 0.73225787} + }; auto vals = std::get<0>(res); auto vecs = std::get<1>(res); EXPECT_TRUE(allclose(expected_0, vals)); @@ -190,7 +199,8 @@ namespace xt EXPECT_TRUE(allclose(complexpected_0, cmvals)); xarray, layout_type::column_major> complexpected_1 = { {-0.92387953 + 0.i, -0.38268343 + 0.i}, - {0.00000000 + 0.38268343i, 0.00000000 - 0.92387953i}}; + {0.00000000 + 0.38268343i, 0.00000000 - 0.92387953i} + }; EXPECT_TRUE(allclose(imag(complexpected_1), imag(cmvecs))); EXPECT_TRUE(allclose(real(complexpected_1), real(cmvecs))); @@ -209,49 +219,35 @@ namespace xt {1.07487297, -0.26417364, 1.82998799, 0.97985789, -0.74820612, -0.75097366}, {0.91375249, 1.14211989, -0.23055478, -0.48264987, -0.4591723, 0.83185472}, {0.05318152, -0.30014836, 1.68456715, 0.07388112, 0.0607432, -0.51529535}, - {-1.36227295, -0.12015569, -0.45599178, -1.07135129, -0.27405687, 0.50177945}}; + {-1.36227295, -0.12015569, -0.45599178, -1.07135129, -0.27405687, 0.50177945} + }; auto res = xt::linalg::pinv(arg_0); xarray expected = { - {-0.12524671, - -0.41299325, - 0.09108576, - -0.07346514, - -0.29603324, - -0.1702256, - 0.28503799, - -0.08979346, - -0.29415286}, - {0.38896886, 0.28355746, -0.26406943, 0.36828839, 0.29271933, 0.19368219, -0.11433648, 0.05102866, 0.11050527}, - {0.05803293, -0.11353613, 0.08736754, -0.20744326, 0.21647828, 0.11201996, 0.26187311, 0.24550066, 0.13766844}, - {0.01037663, 0.15710257, -0.24647364, -0.06611398, 0.0482986, 0.21835662, -0.18763349, 0.01747553, -0.02066564}, - {-0.23681835, - -0.44982904, - 0.14998634, - 0.07608489, - 0.04751429, - -0.27436861, - 0.18626331, - -0.01520331, - -0.18144752}, - {-0.25580995, - -0.33606213, - 0.10424938, - -0.64100544, - 0.01084618, - -0.28963573, - 0.89629794, - 0.12987378, - 0.22451995}}; + {-0.12524671, -0.41299325, 0.09108576, -0.07346514, -0.29603324, -0.1702256, 0.28503799, -0.08979346, -0.29415286 + }, + {0.38896886, 0.28355746, -0.26406943, 0.36828839, 0.29271933, 0.19368219, -0.11433648, 0.05102866, 0.11050527 + }, + {0.05803293, -0.11353613, 0.08736754, -0.20744326, 0.21647828, 0.11201996, 0.26187311, 0.24550066, 0.13766844 + }, + {0.01037663, 0.15710257, -0.24647364, -0.06611398, 0.0482986, 0.21835662, -0.18763349, 0.01747553, -0.02066564 + }, + {-0.23681835, -0.44982904, 0.14998634, 0.07608489, 0.04751429, -0.27436861, 0.18626331, -0.01520331, -0.18144752 + }, + {-0.25580995, -0.33606213, 0.10424938, -0.64100544, 0.01084618, -0.28963573, 0.89629794, 0.12987378, 0.22451995 + } + }; EXPECT_TRUE(allclose(expected, res)); xarray> cmpl_arg_0 = { {-0.32865615 + 1.56868725i, 0.28804396 + 0.52266479i}, {-1.29703842 + 0.34647524i, -2.14982936 + 0.31425111i}, - {-0.69224750 - 1.36725801i, 2.22948403 + 1.4612309i}}; + {-0.69224750 - 1.36725801i, 2.22948403 + 1.4612309i} + }; auto cmpl_res = xt::linalg::pinv(cmpl_arg_0); xarray> cmpl_expected = { {-0.06272312 - 0.24840107i, -0.20530381 - 0.00548715i, -0.14179276 + 0.16337684i}, - {0.05975312 - 0.0502577i, -0.17431091 - 0.05525696i, 0.16047967 - 0.14140846i}}; + {0.05975312 - 0.0502577i, -0.17431091 - 0.05525696i, 0.16047967 - 0.14140846i} + }; EXPECT_TRUE(allclose(real(cmpl_expected), real(cmpl_res))); EXPECT_TRUE(allclose(imag(cmpl_expected), imag(cmpl_res))); } @@ -274,7 +270,8 @@ namespace xt {0.06817001, 0.50274712, -0.36802027, -0.93123204}, {-0.5990272, -0.67439921, -0.09397038, -1.55915724}, {2.22694395, 0.59099048, -0.43162172, 0.19410077}, - {0.41859591, 1.68555153, 1.82660739, 1.24427635}}; + {0.41859591, 1.68555153, 1.82660739, 1.24427635} + }; auto res1 = xt::linalg::norm(arg_0, 1); auto res2 = xt::linalg::norm(arg_0, linalg::normorder::frob); auto res3 = xt::linalg::norm(arg_0, linalg::normorder::inf); @@ -298,7 +295,8 @@ namespace xt xarray> cmplarg_0 = { {0.40101756 + 0.71233018i, 0.62731701 + 0.42786349i, 0.32415089 + 0.2977805i}, {0.24475928 + 0.49208478i, 0.69475518 + 0.74029639i, 0.59390240 + 0.35772892i}, - {0.63179202 + 0.41720995i, 0.44025718 + 0.65472131i, 0.08372648 + 0.37380143i}}; + {0.63179202 + 0.41720995i, 0.44025718 + 0.65472131i, 0.08372648 + 0.37380143i} + }; auto cmplres1 = xt::linalg::norm(cmplarg_0, 1); auto cmplres2 = xt::linalg::norm(cmplarg_0, linalg::normorder::frob); auto cmplres3 = xt::linalg::norm(cmplarg_0, linalg::normorder::inf); @@ -338,7 +336,8 @@ namespace xt xarray> arg_1 = { 0.23451288 + 0.77700444i, 0.98799529 + 0.02798196i, - 0.76599595 + 0.17390652i}; + 0.76599595 + 0.17390652i + }; EXPECT_NEAR(2.58550383197, xt::linalg::norm(arg_1, 1), 1e-06); EXPECT_NEAR(1.50088078633, xt::linalg::norm(arg_1, 2), 1e-06); EXPECT_NEAR(1.25673399279, xt::linalg::norm(arg_1, 3), 1e-06); @@ -362,13 +361,15 @@ namespace xt 0.98799529 + 0.15408224i, 0.76599595 + 0.07708648i, 0.77700444 + 0.8898657i, - 0.02798196 + 0.7503787i}; + 0.02798196 + 0.7503787i + }; xarray> carg_1 = { 0.17390652 + 0.23451288i, 0.15408224 + 0.98799529i, 0.07708648 + 0.76599595i, 0.88986570 + 0.77700444i, - 0.75037870 + 0.02798196i}; + 0.75037870 + 0.02798196i + }; auto res_c = xt::linalg::vdot(carg_0, carg_1); EXPECT_NEAR(1.9289808794290355, std::real(res_c), 1e-06); @@ -391,7 +392,8 @@ namespace xt {6, 0, 12, 8, 14, 18, 0, 36, 24, 42, 6, 0, 12, 8, 14}, {12, 14, 2, 10, 14, 36, 42, 6, 30, 42, 12, 14, 2, 10, 14}, {12, 0, 24, 16, 28, 12, 0, 24, 16, 28, 18, 0, 36, 24, 42}, - {24, 28, 4, 20, 28, 24, 28, 4, 20, 28, 36, 42, 6, 30, 42}}; + {24, 28, 4, 20, 28, 24, 28, 4, 20, 28, 36, 42, 6, 30, 42} + }; EXPECT_EQ(expected, res); } @@ -433,7 +435,8 @@ namespace xt xarray erawR = { {-1.00444014e+01, 0.00000000e+00, 6.74440143e-01, 2.24813381e-01}, {-9.58743044e+00, -1.25730337e+01, -6.22814365e-03, 3.37562246e-01}, - {-1.29027101e+01, -7.34080303e+00, -4.07831856e+00, -5.76331089e-01}}; + {-1.29027101e+01, -7.34080303e+00, -4.07831856e+00, -5.76331089e-01} + }; xarray eTau = {1.32854123, 1.79535299, 1.50132395}; @@ -472,7 +475,8 @@ namespace xt xarray, layout_type::column_major> cel_0 = { {-0.40425532 - 0.38723404i, -0.61702128 - 0.44680851i}, - {1.44680851 + 1.02765957i, 2.51063830 + 0.95744681i}}; + {1.44680851 + 1.02765957i, 2.51063830 + 0.95744681i} + }; xarray cel_1 = {16.11787234, 2.68085106}; int cel_2 = 2; xarray cel_3 = {5.01295356, 1.36758789}; @@ -518,14 +522,16 @@ namespace xt xarray arg_0 = { {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, - {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}}; + {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}} + }; xarray arg_1 = { {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}, - {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}}}; + {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}} + }; auto res1 = xt::linalg::dot(arg_0, arg_1); xarray expected1 = { @@ -539,7 +545,8 @@ namespace xt {{123, 162, 201}, {474, 513, 552}, {825, 864, 903}}, - {{150, 198, 246}, {582, 630, 678}, {1014, 1062, 1110}}}}; + {{150, 198, 246}, {582, 630, 678}, {1014, 1062, 1110}}} + }; EXPECT_TRUE(allclose(expected1, res1)); @@ -561,7 +568,8 @@ namespace xt {{204, 270, 336}, {798, 864, 930}}, - {{231, 306, 381}, {906, 981, 1056}}}}; + {{231, 306, 381}, {906, 981, 1056}}} + }; EXPECT_TRUE(allclose(expected2, res2)); diff --git a/test/test_lstsq.cpp b/test/test_lstsq.cpp index 8e150d6..8a2ea4e 100644 --- a/test/test_lstsq.cpp +++ b/test/test_lstsq.cpp @@ -39,7 +39,8 @@ namespace xt {0.0580836121681995, 0.8661761457749352, 0.6011150117432088}, {0.7080725777960455, 0.0205844942958024, 0.9699098521619943}, {0.8324426408004217, 0.2123391106782762, 0.1818249672071006}, - {0.1834045098534338, 0.3042422429595377, 0.5247564316322378}}; + {0.1834045098534338, 0.3042422429595377, 0.5247564316322378} + }; // py_b xarray py_b = {1., 1., 1., 1., 1., 1.}; // py_res0 = np.linalg.lstsq(a, b)[0] @@ -68,7 +69,8 @@ namespace xt xarray py_a = { {0.4319450186421158, 0.2912291401980419, 0.6118528947223795}, {0.1394938606520418, 0.2921446485352182, 0.3663618432936917}, - {0.4560699842170359, 0.7851759613930136, 0.1996737821583597}}; + {0.4560699842170359, 0.7851759613930136, 0.1996737821583597} + }; // py_b xarray py_b = {1., 1., 1.}; // py_res0 = np.linalg.lstsq(a, b)[0] @@ -98,14 +100,16 @@ namespace xt xarray py_a = { {0.5142344384136116, 0.5924145688620425, 0.0464504127199977}, {0.6075448519014384, 0.1705241236872915, 0.0650515929852795}, - {0.9488855372533332, 0.9656320330745594, 0.8083973481164611}}; + {0.9488855372533332, 0.9656320330745594, 0.8083973481164611} + }; // py_b xarray py_b = {{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}; // py_res0 = np.linalg.lstsq(a, b)[0] xarray py_res0 = { {1.6749237812267237, 1.6749237812267237, 1.6749237812267237}, {0.3213797243357512, 0.3213797243357512, 0.3213797243357512}, - {-1.1128753832544371, -1.1128753832544371, -1.1128753832544371}}; + {-1.1128753832544371, -1.1128753832544371, -1.1128753832544371} + }; // py_res1 = np.linalg.lstsq(a, b)[1] xarray py_res1 = {}; // py_res2 = np.linalg.lstsq(a, b)[2] @@ -130,16 +134,13 @@ namespace xt // py_a xarray py_a = { {0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013, 0.1220382348447788}, - {0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169, 0.662522284353982}}; + {0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169, 0.662522284353982} + }; // py_b xarray py_b = {1., 1.}; // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = { - 0.3137661125421979, - 0.183749537801855, - 0.8404557593671863, - 0.7586648365305537, - -0.1845363594995904}; + xarray py_res0 = + {0.3137661125421979, 0.183749537801855, 0.8404557593671863, 0.7586648365305537, -0.1845363594995904}; // py_res1 = np.linalg.lstsq(a, b)[1] xarray py_res1 = {}; // py_res2 = np.linalg.lstsq(a, b)[2] @@ -163,11 +164,10 @@ namespace xt // py_a xarray py_a = { {0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527, 0.9695846277645586}, - {0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168}}; + {0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168} + }; // py_b - xarray py_b = { - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}}; + xarray py_b = {{1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}}; // py_res0 = np.linalg.lstsq(a, b)[0] xarray py_res0 = { {-0.0723929848964098, @@ -219,7 +219,8 @@ namespace xt 0.882719722037499, 0.882719722037499, 0.882719722037499, - 0.882719722037499}}; + 0.882719722037499} + }; // py_res1 = np.linalg.lstsq(a, b)[1] xarray py_res1 = {}; // py_res2 = np.linalg.lstsq(a, b)[2] @@ -251,7 +252,8 @@ namespace xt {0.7296061783380641, 0.6375574713552131, 0.8872127425763265, 0.4722149251619493, 0.1195942459383017}, {0.713244787222995, 0.7607850486168974, 0.5612771975694962, 0.770967179954561, 0.4937955963643907}, {0.5227328293819941, 0.4275410183585496, 0.0254191267440952, 0.1078914269933045, 0.0314291856867343}, - {0.6364104112637804, 0.3143559810763267, 0.5085706911647028, 0.907566473926093, 0.2492922291488749}}; + {0.6364104112637804, 0.3143559810763267, 0.5085706911647028, 0.907566473926093, 0.2492922291488749} + }; // py_b xarray py_b = { {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, @@ -263,7 +265,8 @@ namespace xt {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}}; + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.} + }; // py_res0 = np.linalg.lstsq(a, b)[0] xarray py_res0 = { {0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, @@ -290,14 +293,16 @@ namespace xt 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, - 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307}}; + 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307} + }; // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = { - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, 0.6683875034141331}; + xarray py_res1 = {0.6683875034141331, 0.6683875034141331, 0.6683875034141331, + 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, + 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, + 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, + 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, + 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, + 0.6683875034141331, 0.6683875034141331}; // py_res2 = np.linalg.lstsq(a, b)[2] int py_res2 = 5; // py_res3 = np.linalg.lstsq(a, b)[3] diff --git a/test/test_qr.cpp b/test/test_qr.cpp index 41e0733..11667a0 100644 --- a/test/test_qr.cpp +++ b/test/test_qr.cpp @@ -42,7 +42,8 @@ namespace xt {0.0580836121681995, 0.8661761457749352, 0.6011150117432088}, {0.7080725777960455, 0.0205844942958024, 0.9699098521619943}, {0.8324426408004217, 0.2123391106782762, 0.1818249672071006}, - {0.1834045098534338, 0.3042422429595377, 0.5247564316322378}}; + {0.1834045098534338, 0.3042422429595377, 0.5247564316322378} + }; // py_resq1_h = res_q1[0] xarray py_resq1_h = { {-1.3152987216651169, @@ -62,7 +63,8 @@ namespace xt 0.7854784971183756, -0.8184018010449023, 0.3355103841692941, - -0.2743559826773574}}; + -0.2743559826773574} + }; // py_resq1_tau = res_q1[1] xarray py_resq1_tau = {1.2847566964660388, 1.3124991842889797, 1.0766465015522177}; @@ -102,12 +104,9 @@ namespace xt 0.1040507467269481, 0.5878955555305321, 0.0326957112268427}, - {-0.1394394344284399, - 0.1841243791750922, - 0.31850193596774, - -0.3303532438685529, - 0.1575155429538277, - 0.8433664457979998}}; + {-0.1394394344284399, 0.1841243791750922, 0.31850193596774, -0.3303532438685529, 0.1575155429538277, 0.8433664457979998 + } + }; // py_resq2_r_cmpl = res_q2[1] xarray py_resq2_r_cmpl = { {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, @@ -115,7 +114,8 @@ namespace xt {0., 0., 0.7854784971183756}, {0., 0., 0.}, {0., 0., 0.}, - {0., 0., 0.}}; + {0., 0., 0.} + }; auto res2 = linalg::qr(py_a, linalg::qrmode::complete); EXPECT_TRUE(allclose(std::get<0>(res2), py_resq2_q_cmpl)); @@ -128,12 +128,14 @@ namespace xt {-0.0441600156766425, 0.6881200538051699, 0.0760152664601146}, {-0.538335943107778, -0.2332659103773061, 0.7525061466150679}, {-0.6328924578795164, -0.1203177215897514, -0.4769214096589269}, - {-0.1394394344284399, 0.1841243791750922, 0.31850193596774}}; + {-0.1394394344284399, 0.1841243791750922, 0.31850193596774} + }; // py_resq3_r_cmpl = res_q3[1] xarray py_resq3_r_cmpl = { {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, {0., 1.2223138676385652, 0.7215655008695085}, - {0., 0., 0.7854784971183756}}; + {0., 0., 0.7854784971183756} + }; auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); EXPECT_TRUE(allclose(std::get<0>(res3), py_resq3_q_cmpl)); @@ -143,7 +145,8 @@ namespace xt xarray py_resq4_r_r = { {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, {0., 1.2223138676385652, 0.7215655008695085}, - {0., 0., 0.7854784971183756}}; + {0., 0., 0.7854784971183756} + }; auto res4 = linalg::qr(py_a, linalg::qrmode::r); EXPECT_TRUE(allclose(std::get<1>(res4), py_resq4_r_r)); @@ -209,19 +212,27 @@ namespace xt 0.2809345096873808, 0.5426960831582485, 0.1409242249747626, - 0.8021969807540397}}; + 0.8021969807540397} + }; // py_resq1_h = res_q1[0] xarray py_resq1_h = { {-1.1430852952870696, 0.3761289948662397, 0.4344253062693247, 0.3471109568548026, 0.0287151863113738}, - {-0.4988738747365853, 0.4145384440977922, -0.1456730968857621, 0.1343802288038163, -0.4549175132696516}, - {-1.0982282164248067, 0.0432498341745755, 0.8009723247566577, -0.2697221220857602, -0.2118640849148783}, - {-0.8189559577243967, 0.2159221672678357, 0.2467828455102148, -0.4358731022610104, 0.0126894274012747}, - {-0.6468222288756241, 0.5399745339753013, 0.9011434603476536, -0.3516828694145329, 0.1205612964483228}, - {-1.6166030169206462, 0.0627336303098124, 0.1745159258713335, -0.1676233275811677, 0.3369911999240203}, + {-0.4988738747365853, 0.4145384440977922, -0.1456730968857621, 0.1343802288038163, -0.4549175132696516 + }, + {-1.0982282164248067, 0.0432498341745755, 0.8009723247566577, -0.2697221220857602, -0.2118640849148783 + }, + {-0.8189559577243967, 0.2159221672678357, 0.2467828455102148, -0.4358731022610104, 0.0126894274012747 + }, + {-0.6468222288756241, 0.5399745339753013, 0.9011434603476536, -0.3516828694145329, 0.1205612964483228 + }, + {-1.6166030169206462, 0.0627336303098124, 0.1745159258713335, -0.1676233275811677, 0.3369911999240203 + }, {-1.1247642047094615, -0.1631138338388988, 0.4469666475320985, 0.229673631977487, 0.3155802843315489}, {-1.5746170823854417, 0.2876936477590398, 0.5186696050660637, 0.0972324032495854, 0.1124970816045023}, - {-0.4678059691956431, 0.0924634343088705, -0.0398310260167535, 0.1199094213119632, 0.1189824829973467}, - {-0.6817147826175952, 0.8209704648352938, 0.1936105292921998, 0.1556371881989978, 0.1610633542281174}}; + {-0.4678059691956431, 0.0924634343088705, -0.0398310260167535, 0.1199094213119632, 0.1189824829973467 + }, + {-0.6817147826175952, 0.8209704648352938, 0.1936105292921998, 0.1556371881989978, 0.1610633542281174} + }; // py_resq1_tau = res_q1[1] xarray py_resq1_tau = {1.3778764545594464, 1.6048419481909388, 1.7894907284949315, 1.9996780087119976, 0.}; @@ -231,11 +242,16 @@ namespace xt EXPECT_TRUE(allclose(std::get<1>(res1), py_resq1_tau)); // py_resq2_q_cmpl = res_q2[0] xarray py_resq2_q_cmpl = { - {-0.3778764545594464, 0.2477850983490842, 0.2323946032026168, 0.6442783657634198, -0.5715855718413764}, - {-0.5182592859033026, -0.5116427882060655, 0.0755411199296714, 0.3718390470559857, 0.5706647283090043}, - {-0.5985844007732788, 0.3414264138444295, -0.6868036045376356, -0.2311050279755064, 0.0039992397751236}, - {-0.4782760145698324, -0.1296501056095487, 0.5617369601749853, -0.6259003290333343, -0.217125009365484}, - {-0.0395659791067296, 0.737185903526121, 0.3912015899373974, 0.0384753772446496, 0.5481536630508244}}; + {-0.3778764545594464, 0.2477850983490842, 0.2323946032026168, 0.6442783657634198, -0.5715855718413764 + }, + {-0.5182592859033026, -0.5116427882060655, 0.0755411199296714, 0.3718390470559857, 0.5706647283090043 + }, + {-0.5985844007732788, 0.3414264138444295, -0.6868036045376356, -0.2311050279755064, 0.0039992397751236 + }, + {-0.4782760145698324, -0.1296501056095487, 0.5617369601749853, -0.6259003290333343, -0.217125009365484 + }, + {-0.0395659791067296, 0.737185903526121, 0.3912015899373974, 0.0384753772446496, 0.5481536630508244} + }; // py_resq2_r_cmpl = res_q2[1] xarray py_resq2_r_cmpl = { {-1.1430852952870696, @@ -287,7 +303,8 @@ namespace xt 0.3155802843315489, 0.1124970816045023, 0.1189824829973467, - 0.1610633542281174}}; + 0.1610633542281174} + }; auto res2 = linalg::qr(py_a, linalg::qrmode::complete); EXPECT_TRUE(allclose(std::get<0>(res2), py_resq2_q_cmpl)); @@ -295,11 +312,16 @@ namespace xt // py_resq3_q_cmpl = res_q3[0] xarray py_resq3_q_cmpl = { - {-0.3778764545594464, 0.2477850983490842, 0.2323946032026168, 0.6442783657634198, -0.5715855718413764}, - {-0.5182592859033026, -0.5116427882060655, 0.0755411199296714, 0.3718390470559857, 0.5706647283090043}, - {-0.5985844007732788, 0.3414264138444295, -0.6868036045376356, -0.2311050279755064, 0.0039992397751236}, - {-0.4782760145698324, -0.1296501056095487, 0.5617369601749853, -0.6259003290333343, -0.217125009365484}, - {-0.0395659791067296, 0.737185903526121, 0.3912015899373974, 0.0384753772446496, 0.5481536630508244}}; + {-0.3778764545594464, 0.2477850983490842, 0.2323946032026168, 0.6442783657634198, -0.5715855718413764 + }, + {-0.5182592859033026, -0.5116427882060655, 0.0755411199296714, 0.3718390470559857, 0.5706647283090043 + }, + {-0.5985844007732788, 0.3414264138444295, -0.6868036045376356, -0.2311050279755064, 0.0039992397751236 + }, + {-0.4782760145698324, -0.1296501056095487, 0.5617369601749853, -0.6259003290333343, -0.217125009365484 + }, + {-0.0395659791067296, 0.737185903526121, 0.3912015899373974, 0.0384753772446496, 0.5481536630508244} + }; // py_resq3_r_cmpl = res_q3[1] xarray py_resq3_r_cmpl = { {-1.1430852952870696, @@ -351,7 +373,8 @@ namespace xt 0.3155802843315489, 0.1124970816045023, 0.1189824829973467, - 0.1610633542281174}}; + 0.1610633542281174} + }; auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); EXPECT_TRUE(allclose(std::get<0>(res3), py_resq3_q_cmpl)); @@ -408,7 +431,8 @@ namespace xt 0.3155802843315489, 0.1124970816045023, 0.1189824829973467, - 0.1610633542281174}}; + 0.1610633542281174} + }; auto res4 = linalg::qr(py_a, linalg::qrmode::r); EXPECT_TRUE(allclose(std::get<1>(res4), py_resq4_r_r)); diff --git a/test/test_tensordot.cpp b/test/test_tensordot.cpp index 0c35990..240818f 100644 --- a/test/test_tensordot.cpp +++ b/test/test_tensordot.cpp @@ -99,7 +99,8 @@ namespace xt xarray a = { {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, - {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}}}; + {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} + }; xarray b = xt::ones({2, 3, 2, 3}); auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; @@ -115,7 +116,8 @@ namespace xt xarray a = { {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, - {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}}}; + {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} + }; xarray b = xt::ones({2, 3, 2, 3}); auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; @@ -131,7 +133,8 @@ namespace xt xarray a = { {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, - {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}}}; + {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} + }; xarray b = xt::ones({2, 3, 2, 3}); auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}};