diff --git a/.github/workflows/linux.yml b/.github/workflows/linux.yml index 2d0c2ac..82cc554 100644 --- a/.github/workflows/linux.yml +++ b/.github/workflows/linux.yml @@ -74,7 +74,7 @@ jobs: run: micromamba install 'openblas==0.3.29=pthreads*' blas-devel - name: Configure using CMake - run: cmake -Bbuild -DDOWNLOAD_GTEST=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib + run: cmake -Bbuild -DBUILD_TESTS=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib - name: Build working-directory: build diff --git a/.github/workflows/osx.yml b/.github/workflows/osx.yml index 6e009dc..b4e8ca2 100644 --- a/.github/workflows/osx.yml +++ b/.github/workflows/osx.yml @@ -39,7 +39,7 @@ jobs: run: micromamba install 'openblas==0.3.29=openmp*' blas-devel - name: Configure using CMake - run: cmake -Bbuild -DCMAKE_CXX_STANDARD=17 -DDOWNLOAD_GTEST=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib + run: cmake -Bbuild -DCMAKE_CXX_STANDARD=17 -DBUILD_TESTS=ON -DCMAKE_PREFIX_PATH=$CONDA_PREFIX -DCMAKE_SYSTEM_IGNORE_PATH=/usr/lib - name: Build working-directory: build diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml index 57ff188..b6efbe1 100644 --- a/.github/workflows/windows.yml +++ b/.github/workflows/windows.yml @@ -51,7 +51,7 @@ jobs: run: micromamba install mkl-devel - name: Configure using CMake - run: cmake -Bbuild -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DDOWNLOAD_GTEST=ON -G Ninja + run: cmake -Bbuild -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX=$CONDA_PREFIX -DBUILD_TESTS=ON -G Ninja - name: Build working-directory: build diff --git a/CMakeLists.txt b/CMakeLists.txt index 8886cda..1c4f5c5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -101,13 +101,6 @@ endif() OPTION(BUILD_TESTS "xtensor-blas test suite" OFF) OPTION(BUILD_BENCHMARK "xtensor-blas test suite" OFF) -OPTION(DOWNLOAD_GTEST "download gtest and build from source" OFF) -OPTION(DOWNLOAD_GBENCHMARK "download google benchmark and build from source" OFF) - -if(DOWNLOAD_GTEST OR GTEST_SRC_DIR) - set(BUILD_TESTS ON) -endif() - if(BUILD_TESTS) enable_testing() include_directories(${XTENSOR_BLAS_INCLUDE_DIR}) diff --git a/environment-dev.yml b/environment-dev.yml index a86cef4..525fbf9 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -4,3 +4,4 @@ channels: dependencies: - cmake - xtensor>=0.26.0,<0.27 +- doctest diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 025233f..9445f66 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -20,6 +20,9 @@ if (CMAKE_CURRENT_SOURCE_DIR STREQUAL CMAKE_SOURCE_DIR) set(XTENSOR_BLAS_INCLUDE_DIR ${xblas_INCLUDE_DIRS}) endif () +find_package(doctest REQUIRED) +find_package(Threads) + if(NOT CMAKE_BUILD_TYPE) message(STATUS "Setting tests build type to Release") set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose the type of build." FORCE) @@ -86,43 +89,6 @@ else() message(FATAL_ERROR "Unsupported compiler: ${CMAKE_CXX_COMPILER_ID}") endif() -if(DOWNLOAD_GTEST OR GTEST_SRC_DIR) - if(DOWNLOAD_GTEST) - # Download and unpack googletest at configure time - configure_file(downloadGTest.cmake.in googletest-download/CMakeLists.txt) - else() - # Copy local source of googletest at configure time - configure_file(copyGTest.cmake.in googletest-download/CMakeLists.txt) - endif() - execute_process(COMMAND ${CMAKE_COMMAND} -G "${CMAKE_GENERATOR}" . - RESULT_VARIABLE result - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) - if(result) - message(FATAL_ERROR "CMake step for googletest failed: ${result}") - endif() - execute_process(COMMAND ${CMAKE_COMMAND} --build . - RESULT_VARIABLE result - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/googletest-download ) - if(result) - message(FATAL_ERROR "Build step for googletest failed: ${result}") - endif() - - set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) - - # Add googletest directly to our build. This defines - # the gtest and gtest_main targets. - add_subdirectory(${CMAKE_CURRENT_BINARY_DIR}/googletest-src - ${CMAKE_CURRENT_BINARY_DIR}/googletest-build EXCLUDE_FROM_ALL) - - set(GTEST_INCLUDE_DIRS "${gtest_SOURCE_DIR}/include") - set(GTEST_BOTH_LIBRARIES gtest_main gtest) -else() - find_package(GTest REQUIRED) -endif() - -find_package(Threads) - -include_directories(${GTEST_INCLUDE_DIRS} SYSTEM) include_directories(${XTENSOR_INCLUDE_DIR}) include_directories(${XBLAS_INCLUDE_DIR}) @@ -154,11 +120,7 @@ set(XTENSOR_BLAS_TESTS ) add_executable(test_xtensor_blas ${XTENSOR_BLAS_TESTS} ${XTENSOR_BLAS_HEADERS} ${XTENSOR_HEADERS}) -if(DOWNLOAD_GTEST OR GTEST_SRC_DIR) - add_dependencies(test_xtensor_blas gtest_main) -endif() - -target_link_libraries(test_xtensor_blas ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES} ${GTEST_BOTH_LIBRARIES} ${CMAKE_THREAD_LIBS_INIT}) +target_link_libraries(test_xtensor_blas ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES} doctest::doctest ${CMAKE_THREAD_LIBS_INIT}) add_custom_target(xtest COMMAND test_xtensor_blas DEPENDS test_xtensor_blas) add_test(NAME xtest COMMAND test_xtensor_blas) diff --git a/test/copyGTest.cmake.in b/test/copyGTest.cmake.in deleted file mode 100644 index 56ccf6d..0000000 --- a/test/copyGTest.cmake.in +++ /dev/null @@ -1,23 +0,0 @@ -############################################################################ -# Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht # -# Copyright (c) QuantStack # -# # -# Distributed under the terms of the BSD 3-Clause License. # -# # -# The full license is in the file LICENSE, distributed with this software. # -############################################################################ - -cmake_minimum_required(VERSION 2.8.2) - -project(googletest-download NONE) - -include(ExternalProject) -ExternalProject_Add(googletest - URL "${GTEST_SRC_DIR}" - SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" - BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - TEST_COMMAND "" -) diff --git a/test/downloadGTest.cmake.in b/test/downloadGTest.cmake.in deleted file mode 100644 index e06bb06..0000000 --- a/test/downloadGTest.cmake.in +++ /dev/null @@ -1,24 +0,0 @@ -############################################################################ -# Copyright (c) Johan Mabille, Sylvain Corlay and Wolf Vollprecht # -# Copyright (c) QuantStack # -# # -# Distributed under the terms of the BSD 3-Clause License. # -# # -# The full license is in the file LICENSE, distributed with this software. # -############################################################################ - -cmake_minimum_required(VERSION 3.29) - -project(googletest-download NONE) - -include(ExternalProject) -ExternalProject_Add(googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG v1.16.0 - SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-src" - BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/googletest-build" - CONFIGURE_COMMAND "" - BUILD_COMMAND "" - INSTALL_COMMAND "" - TEST_COMMAND "" -) diff --git a/test/main.cpp b/test/main.cpp index c73ccc8..d711c29 100644 --- a/test/main.cpp +++ b/test/main.cpp @@ -6,10 +6,5 @@ * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ -#include "gtest/gtest.h" - -int main(int argc, char* argv[]) -{ - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#include "doctest/doctest.h" diff --git a/test/test_blas.cpp b/test/test_blas.cpp index e1e5a52..764cdda 100644 --- a/test/test_blas.cpp +++ b/test/test_blas.cpp @@ -12,171 +12,176 @@ #include "xtensor/generators/xrandom.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xblas.hpp" #include "xtensor-blas/xlinalg.hpp" namespace xt { - TEST(xblas, matrix_times_vector) + TEST_SUITE("xblas") { - xarray m1 = {{1, 2, 3}, {4, 5, 6}}; - xarray b = {1, 2, 3}; - - auto res = linalg::dot(m1, b); - xarray expected = {14, 32}; - EXPECT_EQ(expected, res); - - xarray next_row = {{7, 8, 9}}; - auto res2 = linalg::dot(concatenate(xtuple(m1, next_row)), b); - xarray expected2 = {14, 32, 50}; - EXPECT_EQ(expected2, res2); - } + TEST_CASE("matrix_times_vector") + { + xarray m1 = {{1, 2, 3}, {4, 5, 6}}; + xarray b = {1, 2, 3}; - TEST(xblas, dot_2d) - { - xarray a = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}}; - xarray b = {5, 4, 3, 2, 1}; - xarray c = {1, 2}; - auto res = linalg::dot(a, b); - xarray expected = {35, 35}; + auto res = linalg::dot(m1, b); + xarray expected = {14, 32}; + CHECK_EQ(expected, res); - auto res_ca = linalg::dot(c, a); - xarray expected_ca = {3, 6, 9, 12, 15}; + xarray next_row = {{7, 8, 9}}; + auto res2 = linalg::dot(concatenate(xtuple(m1, next_row)), b); + xarray expected2 = {14, 32, 50}; + CHECK_EQ(expected2, res2); + } - EXPECT_EQ(expected, res); - EXPECT_EQ(expected_ca, res_ca); - } + TEST_CASE("dot_2d") + { + xarray a = {{1, 2, 3, 4, 5}, {1, 2, 3, 4, 5}}; + xarray b = {5, 4, 3, 2, 1}; + xarray c = {1, 2}; + auto res = linalg::dot(a, b); + xarray expected = {35, 35}; - TEST(xblas, matrix_matrix) - { - xarray a = arange(3 * 3); - a.reshape({3, 3}); - xarray b = arange(5 * 3); - b.reshape({3, 5}); - - auto ab = linalg::dot(a, b); - xarray ab_expected = {{25, 28, 31, 34, 37}, {70, 82, 94, 106, 118}, {115, 136, 157, 178, 199}}; - EXPECT_EQ(ab_expected, ab); - } + auto res_ca = linalg::dot(c, a); + xarray expected_ca = {3, 6, 9, 12, 15}; - TEST(xblas, view_dot) - { - xarray a = {1, 2, 3, 4, 5}; - xarray b = {5, 4, 3, 2, 1}; - auto res = linalg::dot(a, b); + CHECK_EQ(expected, res); + CHECK_EQ(expected_ca, res_ca); + } - xarray expected = {35}; - EXPECT_EQ(expected, res); + TEST_CASE("matrix_matrix") + { + xarray a = arange(3 * 3); + a.reshape({3, 3}); + xarray b = arange(5 * 3); + b.reshape({3, 5}); + + auto ab = linalg::dot(a, b); + xarray ab_expected = {{25, 28, 31, 34, 37}, {70, 82, 94, 106, 118}, {115, 136, 157, 178, 199}}; + CHECK_EQ(ab_expected, ab); + } - xarray m1{{1, 2, 3}, {4, 5, 6}}; + TEST_CASE("view_dot") + { + xarray a = {1, 2, 3, 4, 5}; + xarray b = {5, 4, 3, 2, 1}; + auto res = linalg::dot(a, b); - xarray c = {1, 2}; - auto res2 = xt::linalg::dot(xt::view(m1, xt::all(), 1), c); - xarray expected2 = {12}; - EXPECT_EQ(expected2, res2); - } + xarray expected = {35}; + CHECK_EQ(expected, res); - TEST(xblas, norm) - { - auto a = linalg::norm(xt::arange(15), 1); - auto b = linalg::norm(xt::arange(15), 2); - xarray c = {6, 4, 2, 1}; - auto res = linalg::norm(c); - - EXPECT_EQ(a, 105.0); - EXPECT_NEAR(b, 31.859064644147981, 1e-6); - EXPECT_NEAR(res, 7.5498344352707498, 1e-6); - } + xarray m1{{1, 2, 3}, {4, 5, 6}}; - TEST(xblas, normFloat) - { - auto a = linalg::norm(xt::arange(15), 1); - auto b = linalg::norm(xt::arange(15), 2); - xarray c = {6, 4, 2, 1}; - auto res = linalg::norm(c); - - EXPECT_EQ(a, 105.0); - EXPECT_NEAR(b, 31.859064644147981, 1e-6); - EXPECT_NEAR(res, 7.5498344352707498, 1e-6); - } + xarray c = {1, 2}; + auto res2 = xt::linalg::dot(xt::view(m1, xt::all(), 1), c); + xarray expected2 = {12}; + CHECK_EQ(expected2, res2); + } - TEST(xblas, outer) - { - xarray a = {1, 1, 1}; + TEST_CASE("norm") + { + auto a = linalg::norm(xt::arange(15), 1); + auto b = linalg::norm(xt::arange(15), 2); + xarray c = {6, 4, 2, 1}; + auto res = linalg::norm(c); + + CHECK_EQ(a, 105.0); + CHECK(b == doctest::Approx(31.859064644147981).epsilon(1e-6)); + CHECK(res == doctest::Approx(7.5498344352707498).epsilon(1e-6)); + // EXPECT_NEAR(b, 31.859064644147981, 1e-6); + // EXPECT_NEAR(res, 7.5498344352707498, 1e-6); + } - xarray b = arange(0, 3); + TEST_CASE("normFloat") + { + auto a = linalg::norm(xt::arange(15), 1); + auto b = linalg::norm(xt::arange(15), 2); + xarray c = {6, 4, 2, 1}; + auto res = linalg::norm(c); + + CHECK_EQ(a, 105.0); + CHECK(b == doctest::Approx(31.859064644147981).epsilon(1e-6)); + CHECK(res == doctest::Approx(7.5498344352707498).epsilon(1e-6)); + } - xarray expected = {{0, 1, 2}, {0, 1, 2}, {0, 1, 2}}; + TEST_CASE("outer") + { + xarray a = {1, 1, 1}; - auto t = linalg::outer(a, b); - auto t2 = linalg::outer(a, xt::arange(0, 3)); - auto t3 = linalg::outer(xt::ones({3}), xt::arange(0, 3)); + xarray b = arange(0, 3); - EXPECT_TRUE(all(equal(expected, t))); - EXPECT_TRUE(all(equal(expected, t2))); - EXPECT_TRUE(all(equal(expected, t3))); - } + xarray expected = {{0, 1, 2}, {0, 1, 2}, {0, 1, 2}}; - TEST(xblas, outer_random) - { - xt::random::seed(123); - xt::xarray expected = xt::random::randn({5}); - xt::random::seed(123); - auto x = xt::random::randn({5}); - auto weights = xt::xarray({1}); // should perform identity + auto t = linalg::outer(a, b); + auto t2 = linalg::outer(a, xt::arange(0, 3)); + auto t3 = linalg::outer(xt::ones({3}), xt::arange(0, 3)); - auto result = linalg::outer(x, weights); + CHECK(all(equal(expected, t))); + CHECK(all(equal(expected, t2))); + CHECK(all(equal(expected, t3))); + } - // shapes are different - for (std::size_t i = 0; i < 5; ++i) + TEST_CASE("outer_random") { - EXPECT_EQ(result.data()[i], expected.data()[i]); + xt::random::seed(123); + xt::xarray expected = xt::random::randn({5}); + xt::random::seed(123); + auto x = xt::random::randn({5}); + auto weights = xt::xarray({1}); // should perform identity + + auto result = linalg::outer(x, weights); + + // shapes are different + for (std::size_t i = 0; i < 5; ++i) + { + CHECK_EQ(result.data()[i], expected.data()[i]); + } } - } - TEST(xblas, nan_result) - { - xt::xarray X = {{1, 2, 3}, {1, 2, 3}}; - - auto M = xt::xarray::from_shape({3, 3}); - M(0, 0) = std::numeric_limits::quiet_NaN(); - M(0, 1) = std::numeric_limits::quiet_NaN(); - xt::blas::gemm(X, X, M, true, false, 1.0, 0.0); - for (std::size_t i = 0; i < M.size(); ++i) + TEST_CASE("nan_result") { - EXPECT_FALSE(std::isnan(M.storage()[i])); + xt::xarray X = {{1, 2, 3}, {1, 2, 3}}; + + auto M = xt::xarray::from_shape({3, 3}); + M(0, 0) = std::numeric_limits::quiet_NaN(); + M(0, 1) = std::numeric_limits::quiet_NaN(); + xt::blas::gemm(X, X, M, true, false, 1.0, 0.0); + for (std::size_t i = 0; i < M.size(); ++i) + { + CHECK_FALSE(std::isnan(M.storage()[i])); + } } - } - TEST(xblas, gemm_transpose) - { - xt::xarray X = {{1, 2, 3}, {1, 2, 3}}; + TEST_CASE("gemm_transpose") + { + xt::xarray X = {{1, 2, 3}, {1, 2, 3}}; - auto M = xt::xarray::from_shape({3, 3}); - auto O = xt::xarray::from_shape({2, 2}); + auto M = xt::xarray::from_shape({3, 3}); + auto O = xt::xarray::from_shape({2, 2}); - xt::blas::gemm(X, X, M, true, false, 1.0, 0.0); - xt::blas::gemm(X, X, O, false, true, 1.0, 0.0); + xt::blas::gemm(X, X, M, true, false, 1.0, 0.0); + xt::blas::gemm(X, X, O, false, true, 1.0, 0.0); - xt::xarray expM = {{2, 4, 6}, {4, 8, 12}, {6, 12, 18}}; + xt::xarray expM = {{2, 4, 6}, {4, 8, 12}, {6, 12, 18}}; - xt::xarray expO = {{14, 14}, {14, 14}}; + xt::xarray expO = {{14, 14}, {14, 14}}; - EXPECT_TRUE(all(equal(expM, M))); - EXPECT_TRUE(all(equal(expO, O))); - } + CHECK(all(equal(expM, M))); + CHECK(all(equal(expO, O))); + } - TEST(xblas, gemv_transpose) - { - xt::xarray X = {{1, 2, 3}, {1, 2, 3}}; - xt::xarray v = {1, 2}; - auto R = xt::xarray::from_shape({3}); + TEST_CASE("gemv_transpose") + { + xt::xarray X = {{1, 2, 3}, {1, 2, 3}}; + xt::xarray v = {1, 2}; + auto R = xt::xarray::from_shape({3}); - xt::blas::gemv(X, v, R, true, 1, 0); + xt::blas::gemv(X, v, R, true, 1, 0); - xt::xarray expR = {3, 6, 9}; + xt::xarray expR = {3, 6, 9}; - EXPECT_TRUE(all(equal(expR, R))); + CHECK(all(equal(expR, R))); + } } } // namespace xt diff --git a/test/test_dot.cpp b/test/test_dot.cpp index 63acf68..ca16319 100644 --- a/test/test_dot.cpp +++ b/test/test_dot.cpp @@ -12,210 +12,214 @@ #include "xtensor/views/xstrided_view.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xlinalg.hpp" namespace xt { - TEST(xdot, matrix_times_vector) + TEST_SUITE("xdot") { - xarray a = xt::ones({1, 4}); - xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}}; + TEST_CASE("matrix_times_vector") + { + xarray a = xt::ones({1, 4}); + xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}}; - xarray e1 = {{13, 16, 19}}; + xarray e1 = {{13, 16, 19}}; - auto r1 = linalg::dot(a, b); - EXPECT_EQ(e1, r1); + auto r1 = linalg::dot(a, b); + CHECK_EQ(e1, r1); - xarray c = xt::ones({3, 1}); + xarray c = xt::ones({3, 1}); - auto r2 = linalg::dot(b, c); - xarray e2 = {{6, 15, 24, 3}}; - e2.reshape({4, 1}); - EXPECT_EQ(e2, r2); + auto r2 = linalg::dot(b, c); + xarray e2 = {{6, 15, 24, 3}}; + e2.reshape({4, 1}); + CHECK_EQ(e2, r2); - EXPECT_THROW(linalg::dot(b, a), std::runtime_error); - EXPECT_THROW(linalg::dot(c, b), std::runtime_error); - } + CHECK_THROWS_AS(linalg::dot(b, a), std::runtime_error); + CHECK_THROWS_AS(linalg::dot(c, b), std::runtime_error); + } - TEST(xdot, matrix_transpose_times_column) - { - xarray a = xt::ones({2, 4}); - xarray b = xt::ones({2, 1}); - auto r1 = linalg::dot(xt::transpose(a), b); - EXPECT_TRUE(all(equal(r1, 2.0))); - } + TEST_CASE("matrix_transpose_times_column") + { + xarray a = xt::ones({2, 4}); + xarray b = xt::ones({2, 1}); + auto r1 = linalg::dot(xt::transpose(a), b); + CHECK(all(equal(r1, 2.0))); + } - TEST(xdot, matrix_transpose_times_column_cm) - { - xarray a = xt::ones({2, 4}); - xarray b = xt::ones({2, 1}); + TEST_CASE("matrix_transpose_times_column_cm") + { + xarray a = xt::ones({2, 4}); + xarray b = xt::ones({2, 1}); - auto r1 = linalg::dot(xt::transpose(a), b); - EXPECT_TRUE(all(equal(r1, 2.0))); - } + auto r1 = linalg::dot(xt::transpose(a), b); + CHECK(all(equal(r1, 2.0))); + } - TEST(xdot, square_matrix_times_vector) - { - xarray a = {{1, 1, 1}}; - xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; + TEST_CASE("square_matrix_times_vector") + { + xarray a = {{1, 1, 1}}; + xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; - auto r1 = linalg::dot(a, b); + auto r1 = linalg::dot(a, b); - xarray e1 = {{12, 15, 18}}; - EXPECT_EQ(r1, e1); + xarray e1 = {{12, 15, 18}}; + CHECK_EQ(r1, e1); - auto r2 = linalg::dot(b, xt::transpose(a)); - xarray e2 = xarray::from_shape({3, 1}); - e2(0, 0) = 6.f; - e2(1, 0) = 15.f; - e2(2, 0) = 24.f; - EXPECT_EQ(r2, e2); + auto r2 = linalg::dot(b, xt::transpose(a)); + xarray e2 = xarray::from_shape({3, 1}); + e2(0, 0) = 6.f; + e2(1, 0) = 15.f; + e2(2, 0) = 24.f; + CHECK_EQ(r2, e2); - EXPECT_THROW(linalg::dot(b, a), std::runtime_error); - } + CHECK_THROWS_AS(linalg::dot(b, a), std::runtime_error); + } - TEST(xdot, vector_times_vector) - { - xarray a = xt::ones({1, 3}); - xarray b = xt::ones({3, 1}); + TEST_CASE("vector_times_vector") + { + xarray a = xt::ones({1, 3}); + xarray b = xt::ones({3, 1}); - auto r1 = linalg::dot(a, b); + auto r1 = linalg::dot(a, b); - xarray e1 = xarray::from_shape({1, 1}); - e1(0, 0) = 3; + xarray e1 = xarray::from_shape({1, 1}); + e1(0, 0) = 3; - EXPECT_EQ(e1, r1); + CHECK_EQ(e1, r1); - auto r2 = linalg::dot(b, a); - xarray e2 = xt::ones({3, 3}); - EXPECT_EQ(e2, r2); + auto r2 = linalg::dot(b, a); + xarray e2 = xt::ones({3, 3}); + CHECK_EQ(e2, r2); - auto r3 = linalg::dot(b, e1); - EXPECT_EQ(b * 3.f, r3); - } + auto r3 = linalg::dot(b, e1); + CHECK_EQ(b * 3.f, r3); + } - TEST(xdot, matrix_times_1d) - { - xarray a = xt::ones({5, 3}); - xarray b = xt::ones({5}); - xarray c = xt::ones({3}); - auto r1 = linalg::dot(xt::transpose(a), b); + TEST_CASE("matrix_times_1d") + { + xarray a = xt::ones({5, 3}); + xarray b = xt::ones({5}); + xarray c = xt::ones({3}); + auto r1 = linalg::dot(xt::transpose(a), b); - EXPECT_TRUE(all(equal(r1, 5.0))); + CHECK(all(equal(r1, 5.0))); - auto r2 = linalg::dot(c, xt::transpose(a)); - EXPECT_TRUE(all(equal(r2, 3.0))); + auto r2 = linalg::dot(c, xt::transpose(a)); + CHECK(all(equal(r2, 3.0))); - auto r3 = linalg::dot(a, c); - EXPECT_TRUE(all(equal(r3, 3.0))); + auto r3 = linalg::dot(a, c); + CHECK(all(equal(r3, 3.0))); - auto r4 = linalg::dot(c, xt::ones({3, 5})); - EXPECT_TRUE(all(equal(r3, 3.0))); - } + auto r4 = linalg::dot(c, xt::ones({3, 5})); + CHECK(all(equal(r3, 3.0))); + } - TEST(xdot, A_times_A_T) - { - xarray a = xt::ones({5, 3}); + TEST_CASE("A_times_A_T") + { + xarray a = xt::ones({5, 3}); - auto r1 = linalg::dot(a, xt::transpose(a)); - EXPECT_TRUE(all(equal(r1, 3.0))); + auto r1 = linalg::dot(a, xt::transpose(a)); + CHECK(all(equal(r1, 3.0))); - auto r2 = linalg::dot(xt::transpose(a), a); - EXPECT_TRUE(all(equal(r2, 5.0))); - } + auto r2 = linalg::dot(xt::transpose(a), a); + CHECK(all(equal(r2, 5.0))); + } - TEST(xdot, matrix_times_vector_cm) - { - xarray a = xt::ones({1, 4}); - xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}}; + TEST_CASE("matrix_times_vector_cm") + { + xarray a = xt::ones({1, 4}); + xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {1, 1, 1}}; + + xarray e1 = {{13, 16, 19}}; + + auto r1 = linalg::dot(a, b); + CHECK_EQ(e1, r1); + + xarray c = xt::ones({3, 1}); - xarray e1 = {{13, 16, 19}}; + auto r2 = linalg::dot(b, c); + xarray e2 = {{6, 15, 24, 3}}; + e2.reshape({4, 1}); + CHECK_EQ(e2, r2); - auto r1 = linalg::dot(a, b); - EXPECT_EQ(e1, r1); + CHECK_THROWS_AS(linalg::dot(b, a), std::runtime_error); + CHECK_THROWS_AS(linalg::dot(c, b), std::runtime_error); + } - xarray c = xt::ones({3, 1}); + TEST_CASE("square_matrix_times_vector_cm") + { + xarray a = {{1, 1, 1}}; + xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; - auto r2 = linalg::dot(b, c); - xarray e2 = {{6, 15, 24, 3}}; - e2.reshape({4, 1}); - EXPECT_EQ(e2, r2); + auto r1 = linalg::dot(a, b); - EXPECT_THROW(linalg::dot(b, a), std::runtime_error); - EXPECT_THROW(linalg::dot(c, b), std::runtime_error); - } + xarray e1 = {{12, 15, 18}}; + CHECK_EQ(r1, e1); - TEST(xdot, square_matrix_times_vector_cm) - { - xarray a = {{1, 1, 1}}; - xarray b = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}; + auto r2 = linalg::dot(b, xt::transpose(a)); + xarray e2 = xarray::from_shape( + {3, 1} + ); + e2(0, 0) = 6.f; + e2(1, 0) = 15.f; + e2(2, 0) = 24.f; + CHECK_EQ(r2, e2); - auto r1 = linalg::dot(a, b); + CHECK_THROWS_AS(linalg::dot(b, a), std::runtime_error); + } - xarray e1 = {{12, 15, 18}}; - EXPECT_EQ(r1, e1); + TEST_CASE("vector_times_vector_cm") + { + xarray a = xt::ones({1, 3}); + xarray b = xt::ones({3, 1}); - auto r2 = linalg::dot(b, xt::transpose(a)); - xarray e2 = xarray::from_shape({3, 1} - ); - e2(0, 0) = 6.f; - e2(1, 0) = 15.f; - e2(2, 0) = 24.f; - EXPECT_EQ(r2, e2); + auto r1 = linalg::dot(a, b); - EXPECT_THROW(linalg::dot(b, a), std::runtime_error); - } + xarray e1 = xarray::from_shape( + {1, 1} + ); + e1(0, 0) = 3; - TEST(xdot, vector_times_vector_cm) - { - xarray a = xt::ones({1, 3}); - xarray b = xt::ones({3, 1}); + CHECK_EQ(e1, r1); - auto r1 = linalg::dot(a, b); + auto r2 = linalg::dot(b, a); + xarray e2 = xt::ones({3, 3}); + CHECK_EQ(e2, r2); - xarray e1 = xarray::from_shape({1, 1} - ); - e1(0, 0) = 3; + auto r3 = linalg::dot(b, e1); + CHECK_EQ(b * 3.f, r3); + } - EXPECT_EQ(e1, r1); + TEST_CASE("on_view") + { + xt::xarray a = xt::reshape_view(xt::arange(10 * 10 * 10), {10, 10, 10}); + xt::xarray b = xt::reshape_view(xt::arange(10 * 10 * 10), {10, 10, 10}); - auto r2 = linalg::dot(b, a); - xarray e2 = xt::ones({3, 3}); - EXPECT_EQ(e2, r2); + auto res = xt::linalg::dot(view(a, 0, 0), view(b, 0, 0)); + auto res1 = xt::linalg::dot(view(a, 0, range(0, 3)), transpose(view(b, 0, range(0, 3)))); - auto r3 = linalg::dot(b, e1); - EXPECT_EQ(b * 3.f, r3); - } + CHECK_EQ(res(0), 285.); + CHECK_EQ(res1(0, 0), 285.); + CHECK_EQ(res1(1, 2), 3635.); - TEST(xdot, on_view) - { - xt::xarray a = xt::reshape_view(xt::arange(10 * 10 * 10), {10, 10, 10}); - xt::xarray b = xt::reshape_view(xt::arange(10 * 10 * 10), {10, 10, 10}); - - auto res = xt::linalg::dot(view(a, 0, 0), view(b, 0, 0)); - auto res1 = xt::linalg::dot(view(a, 0, range(0, 3)), transpose(view(b, 0, range(0, 3)))); - - EXPECT_EQ(res(0), 285.); - EXPECT_EQ(res1(0, 0), 285.); - EXPECT_EQ(res1(1, 2), 3635.); - - EXPECT_EQ(res1.dimension(), 2u); - EXPECT_EQ(res1.shape()[0], 3u); - EXPECT_EQ(res1.shape()[1], 3u); - - auto res2 = xt::linalg::dot(strided_view(a, {0, 0}), strided_view(b, {0, 0})); - auto res3 = xt::linalg::dot( - strided_view(a, {0, range(0, 3)}), - transpose(strided_view(b, {0, range(0, 3)})) - ); - EXPECT_EQ(res2(0), 285.); - EXPECT_EQ(res3(0, 0), 285.); - EXPECT_EQ(res3(1, 2), 3635.); - - EXPECT_EQ(res3.dimension(), 2u); - EXPECT_EQ(res3.shape()[0], 3u); - EXPECT_EQ(res3.shape()[1], 3u); - } + CHECK_EQ(res1.dimension(), 2u); + CHECK_EQ(res1.shape()[0], 3u); + CHECK_EQ(res1.shape()[1], 3u); + auto res2 = xt::linalg::dot(strided_view(a, {0, 0}), strided_view(b, {0, 0})); + auto res3 = xt::linalg::dot( + strided_view(a, {0, range(0, 3)}), + transpose(strided_view(b, {0, range(0, 3)})) + ); + CHECK_EQ(res2(0), 285.); + CHECK_EQ(res3(0, 0), 285.); + CHECK_EQ(res3(1, 2), 3635.); + + CHECK_EQ(res3.dimension(), 2u); + CHECK_EQ(res3.shape()[0], 3u); + CHECK_EQ(res3.shape()[1], 3u); + } + } } // namespace xt diff --git a/test/test_dot_extended.cpp b/test/test_dot_extended.cpp index c93bc1e..3e8aef2 100644 --- a/test/test_dot_extended.cpp +++ b/test/test_dot_extended.cpp @@ -6,184 +6,209 @@ * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ +// This file is generated from test/files/cppy_source/test_dot_extended.cppy by preprocess.py! -// This file is generated from test/files/cppy_source/test_dot_extended.cppy by -// preprocess.py! #include #include "xtensor/containers/xarray.hpp" #include "xtensor/containers/xtensor.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xlinalg.hpp" namespace xt { using namespace xt::placeholders; - /*py - a = np.random.random((2, 3, 5)) - b = np.random.random((4, 5)) - dr = np.dot(a, b.T) - */ - TEST(xtest_extended, dot_broadcast) + TEST_SUITE("xdot_extended") { - // py_a - xarray py_a = { - {{0.3745401188473625, 0.9507143064099162, 0.7319939418114051, 0.5986584841970366, 0.1560186404424365}, - {0.1559945203362026, 0.0580836121681995, 0.8661761457749352, 0.6011150117432088, 0.7080725777960455}, - {0.0205844942958024, 0.9699098521619943, 0.8324426408004217, 0.2123391106782762, 0.1818249672071006 - }}, - - {{0.1834045098534338, 0.3042422429595377, 0.5247564316322378, 0.4319450186421158, 0.2912291401980419}, - {0.6118528947223795, 0.1394938606520418, 0.2921446485352182, 0.3663618432936917, 0.4560699842170359}, - {0.7851759613930136, 0.1996737821583597, 0.5142344384136116, 0.5924145688620425, 0.0464504127199977}} - }; - // py_b - xarray py_b = { - {0.6075448519014384, 0.1705241236872915, 0.0650515929852795, 0.9488855372533332, 0.9656320330745594}, - {0.8083973481164611, 0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013}, - {0.1220382348447788, 0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169}, - {0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527} - }; - // py_dr - xarray py_dr = { - {{1.1560019913607258, 1.1421672030085086, 1.1263990512143978, 1.2813094834150083}, - {1.415151366639716, 0.9513625344824885, 0.807426629014782, 1.0314517921651605}, - {0.6091122748507029, 0.6187149240291543, 0.7515524775267591, 0.898595256683809}}, - - {{0.8885299172713558, 0.7159304454839006, 0.6592223836380569, 0.7792380767202456}, - {1.2025508600129964, 1.0170636073271262, 0.6049520893427571, 0.8853834024749684}, - {1.15151820221699, 1.1715787914743192, 0.763094187597877, 1.182339688054495}} - }; - - xt::xtensor bas = xt::transpose(py_b); - - auto xres = xt::linalg::dot(py_a, xt::transpose(py_b)); - auto xres2 = xt::linalg::dot(py_a, bas); - std::cout << xres << std::endl; - EXPECT_TRUE(xt::allclose(xres, py_dr)); - EXPECT_TRUE(xt::allclose(xres2, py_dr)); - } + /*py + a = np.random.random((2, 3, 5)) + b = np.random.random((4, 5)) + dr = np.dot(a, b.T) + */ + TEST_CASE("dot_broadcast") + { + // py_a + xarray py_a = { + {{0.3745401188473625, 0.9507143064099162, 0.7319939418114051, 0.5986584841970366, 0.1560186404424365 + }, + {0.1559945203362026, 0.0580836121681995, 0.8661761457749352, 0.6011150117432088, 0.7080725777960455 + }, + {0.0205844942958024, 0.9699098521619943, 0.8324426408004217, 0.2123391106782762, 0.1818249672071006 + }}, + + {{0.1834045098534338, 0.3042422429595377, 0.5247564316322378, 0.4319450186421158, 0.2912291401980419 + }, + {0.6118528947223795, 0.1394938606520418, 0.2921446485352182, 0.3663618432936917, 0.4560699842170359 + }, + {0.7851759613930136, 0.1996737821583597, 0.5142344384136116, 0.5924145688620425, 0.0464504127199977 + }} + }; + // py_b + xarray py_b = { + {0.6075448519014384, 0.1705241236872915, 0.0650515929852795, 0.9488855372533332, 0.9656320330745594 + }, + {0.8083973481164611, 0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013 + }, + {0.1220382348447788, 0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169 + }, + {0.662522284353982, 0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527} + }; + // py_dr + xarray py_dr = { + {{1.1560019913607258, 1.1421672030085086, 1.1263990512143978, 1.2813094834150083}, + {1.415151366639716, 0.9513625344824885, 0.807426629014782, 1.0314517921651605}, + {0.6091122748507029, 0.6187149240291543, 0.7515524775267591, 0.898595256683809}}, + + {{0.8885299172713558, 0.7159304454839006, 0.6592223836380569, 0.7792380767202456}, + {1.2025508600129964, 1.0170636073271262, 0.6049520893427571, 0.8853834024749684}, + {1.15151820221699, 1.1715787914743192, 0.763094187597877, 1.182339688054495}} + }; + + xt::xtensor bas = xt::transpose(py_b); + + auto xres = xt::linalg::dot(py_a, xt::transpose(py_b)); + auto xres2 = xt::linalg::dot(py_a, bas); + CHECK(xt::allclose(xres, py_dr)); + CHECK(xt::allclose(xres2, py_dr)); + } + + /*py + a = np.random.random((2, 3, 5)) + b = np.random.random((5)) + dr = np.dot(a, b) + */ + TEST_CASE("dot_broadcast_2") + { + // py_a + xarray py_a = { + {{0.9695846277645586, 0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851 + }, + {0.9218742350231168, 0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643 + }, + {0.388677289689482, 0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808 + }}, + + {{0.5426960831582485, 0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173 + }, + {0.7722447692966574, 0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171 + }, + {0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297 + }} + }; + // py_b + xarray py_b = { + 0.8631034258755935, + 0.6232981268275579, + 0.3308980248526492, + 0.0635583502860236, + 0.3109823217156622 + }; + // py_dr + xarray py_dr = { + {1.8736790686065976, 1.0197269167779506, 0.8888679673881792}, + {1.1333287572487494, 1.0638629967411402, 1.1932578950872312} + }; - /*py - a = np.random.random((2, 3, 5)) - b = np.random.random((5)) - dr = np.dot(a, b) - */ - TEST(xtest_extended, dot_broadcast_2) - { - // py_a - xarray py_a = { - {{0.9695846277645586, 0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851}, - {0.9218742350231168, 0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643}, - {0.388677289689482, 0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808}}, - - {{0.5426960831582485, 0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173}, - {0.7722447692966574, 0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171}, - {0.7290071680409873, 0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297}} - }; - // py_b - xarray py_b = - {0.8631034258755935, 0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622}; - // py_dr - xarray py_dr = { - {1.8736790686065976, 1.0197269167779506, 0.8888679673881792}, - {1.1333287572487494, 1.0638629967411402, 1.1932578950872312} - }; - - auto xres = xt::linalg::dot(py_a, py_b); - std::cout << xres << std::endl; - EXPECT_TRUE(xt::allclose(xres, py_dr)); - } + auto xres = xt::linalg::dot(py_a, py_b); + CHECK(xt::allclose(xres, py_dr)); + } - /*py - a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) - b = np.arange(4 * 5 * 3).reshape(4, 5, 3) - dr = np.dot(a, b) - */ - TEST(xtest_extended, dot_broadcast_3) - { - // py_a - xarray py_a = { - {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}, {10, 11, 12, 13, 14}, {15, 16, 17, 18, 19}}, + /*py + a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) + b = np.arange(4 * 5 * 3).reshape(4, 5, 3) + dr = np.dot(a, b) + */ + TEST_CASE("dot_broadcast_3") + { + // py_a + xarray py_a = { + {{{0, 1, 2, 3, 4}, {5, 6, 7, 8, 9}, {10, 11, 12, 13, 14}, {15, 16, 17, 18, 19}}, + + {{20, 21, 22, 23, 24}, {25, 26, 27, 28, 29}, {30, 31, 32, 33, 34}, {35, 36, 37, 38, 39}}, + + {{40, 41, 42, 43, 44}, {45, 46, 47, 48, 49}, {50, 51, 52, 53, 54}, {55, 56, 57, 58, 59}}}, + + + {{{60, 61, 62, 63, 64}, {65, 66, 67, 68, 69}, {70, 71, 72, 73, 74}, {75, 76, 77, 78, 79}}, + + {{80, 81, 82, 83, 84}, {85, 86, 87, 88, 89}, {90, 91, 92, 93, 94}, {95, 96, 97, 98, 99}}, + + {{100, 101, 102, 103, 104}, + {105, 106, 107, 108, 109}, + {110, 111, 112, 113, 114}, + {115, 116, 117, 118, 119}}} + }; + // py_b + xarray py_b = { + {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}, {12, 13, 14}}, - {{20, 21, 22, 23, 24}, {25, 26, 27, 28, 29}, {30, 31, 32, 33, 34}, {35, 36, 37, 38, 39}}, + {{15, 16, 17}, {18, 19, 20}, {21, 22, 23}, {24, 25, 26}, {27, 28, 29}}, - {{40, 41, 42, 43, 44}, {45, 46, 47, 48, 49}, {50, 51, 52, 53, 54}, {55, 56, 57, 58, 59}}}, + {{30, 31, 32}, {33, 34, 35}, {36, 37, 38}, {39, 40, 41}, {42, 43, 44}}, - {{{60, 61, 62, 63, 64}, {65, 66, 67, 68, 69}, {70, 71, 72, 73, 74}, {75, 76, 77, 78, 79}}, + {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}, {54, 55, 56}, {57, 58, 59}} + }; + // py_dr + xarray py_dr = { + {{{{90, 100, 110}, {240, 250, 260}, {390, 400, 410}, {540, 550, 560}}, - {{80, 81, 82, 83, 84}, {85, 86, 87, 88, 89}, {90, 91, 92, 93, 94}, {95, 96, 97, 98, 99}}, + {{240, 275, 310}, {765, 800, 835}, {1290, 1325, 1360}, {1815, 1850, 1885}}, - {{100, 101, 102, 103, 104}, - {105, 106, 107, 108, 109}, - {110, 111, 112, 113, 114}, - {115, 116, 117, 118, 119}}} - }; - // py_b - xarray py_b = { - {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}, {12, 13, 14}}, + {{390, 450, 510}, {1290, 1350, 1410}, {2190, 2250, 2310}, {3090, 3150, 3210}}, - {{15, 16, 17}, {18, 19, 20}, {21, 22, 23}, {24, 25, 26}, {27, 28, 29}}, + {{540, 625, 710}, {1815, 1900, 1985}, {3090, 3175, 3260}, {4365, 4450, 4535}}}, - {{30, 31, 32}, {33, 34, 35}, {36, 37, 38}, {39, 40, 41}, {42, 43, 44}}, - {{45, 46, 47}, {48, 49, 50}, {51, 52, 53}, {54, 55, 56}, {57, 58, 59}} - }; - // py_dr - xarray py_dr = { - {{{{90, 100, 110}, {240, 250, 260}, {390, 400, 410}, {540, 550, 560}}, + {{{690, 800, 910}, {2340, 2450, 2560}, {3990, 4100, 4210}, {5640, 5750, 5860}}, - {{240, 275, 310}, {765, 800, 835}, {1290, 1325, 1360}, {1815, 1850, 1885}}, + {{840, 975, 1110}, {2865, 3000, 3135}, {4890, 5025, 5160}, {6915, 7050, 7185}}, - {{390, 450, 510}, {1290, 1350, 1410}, {2190, 2250, 2310}, {3090, 3150, 3210}}, + {{990, 1150, 1310}, {3390, 3550, 3710}, {5790, 5950, 6110}, {8190, 8350, 8510}}, - {{540, 625, 710}, {1815, 1900, 1985}, {3090, 3175, 3260}, {4365, 4450, 4535}}}, + {{1140, 1325, 1510}, {3915, 4100, 4285}, {6690, 6875, 7060}, {9465, 9650, 9835}}}, - {{{690, 800, 910}, {2340, 2450, 2560}, {3990, 4100, 4210}, {5640, 5750, 5860}}, - {{840, 975, 1110}, {2865, 3000, 3135}, {4890, 5025, 5160}, {6915, 7050, 7185}}, + {{{1290, 1500, 1710}, {4440, 4650, 4860}, {7590, 7800, 8010}, {10740, 10950, 11160}}, - {{990, 1150, 1310}, {3390, 3550, 3710}, {5790, 5950, 6110}, {8190, 8350, 8510}}, + {{1440, 1675, 1910}, {4965, 5200, 5435}, {8490, 8725, 8960}, {12015, 12250, 12485}}, - {{1140, 1325, 1510}, {3915, 4100, 4285}, {6690, 6875, 7060}, {9465, 9650, 9835}}}, + {{1590, 1850, 2110}, {5490, 5750, 6010}, {9390, 9650, 9910}, {13290, 13550, 13810}}, - {{{1290, 1500, 1710}, {4440, 4650, 4860}, {7590, 7800, 8010}, {10740, 10950, 11160}}, + {{1740, 2025, 2310}, {6015, 6300, 6585}, {10290, 10575, 10860}, {14565, 14850, 15135}}}}, - {{1440, 1675, 1910}, {4965, 5200, 5435}, {8490, 8725, 8960}, {12015, 12250, 12485}}, - {{1590, 1850, 2110}, {5490, 5750, 6010}, {9390, 9650, 9910}, {13290, 13550, 13810}}, + {{{{1890, 2200, 2510}, {6540, 6850, 7160}, {11190, 11500, 11810}, {15840, 16150, 16460}}, - {{1740, 2025, 2310}, {6015, 6300, 6585}, {10290, 10575, 10860}, {14565, 14850, 15135}}}}, + {{2040, 2375, 2710}, {7065, 7400, 7735}, {12090, 12425, 12760}, {17115, 17450, 17785}}, - {{{{1890, 2200, 2510}, {6540, 6850, 7160}, {11190, 11500, 11810}, {15840, 16150, 16460}}, + {{2190, 2550, 2910}, {7590, 7950, 8310}, {12990, 13350, 13710}, {18390, 18750, 19110}}, - {{2040, 2375, 2710}, {7065, 7400, 7735}, {12090, 12425, 12760}, {17115, 17450, 17785}}, + {{2340, 2725, 3110}, {8115, 8500, 8885}, {13890, 14275, 14660}, {19665, 20050, 20435}}}, - {{2190, 2550, 2910}, {7590, 7950, 8310}, {12990, 13350, 13710}, {18390, 18750, 19110}}, - {{2340, 2725, 3110}, {8115, 8500, 8885}, {13890, 14275, 14660}, {19665, 20050, 20435}}}, + {{{2490, 2900, 3310}, {8640, 9050, 9460}, {14790, 15200, 15610}, {20940, 21350, 21760}}, - {{{2490, 2900, 3310}, {8640, 9050, 9460}, {14790, 15200, 15610}, {20940, 21350, 21760}}, + {{2640, 3075, 3510}, {9165, 9600, 10035}, {15690, 16125, 16560}, {22215, 22650, 23085}}, - {{2640, 3075, 3510}, {9165, 9600, 10035}, {15690, 16125, 16560}, {22215, 22650, 23085}}, + {{2790, 3250, 3710}, {9690, 10150, 10610}, {16590, 17050, 17510}, {23490, 23950, 24410}}, - {{2790, 3250, 3710}, {9690, 10150, 10610}, {16590, 17050, 17510}, {23490, 23950, 24410}}, + {{2940, 3425, 3910}, {10215, 10700, 11185}, {17490, 17975, 18460}, {24765, 25250, 25735}}}, - {{2940, 3425, 3910}, {10215, 10700, 11185}, {17490, 17975, 18460}, {24765, 25250, 25735}}}, - {{{3090, 3600, 4110}, {10740, 11250, 11760}, {18390, 18900, 19410}, {26040, 26550, 27060}}, + {{{3090, 3600, 4110}, {10740, 11250, 11760}, {18390, 18900, 19410}, {26040, 26550, 27060}}, - {{3240, 3775, 4310}, {11265, 11800, 12335}, {19290, 19825, 20360}, {27315, 27850, 28385}}, + {{3240, 3775, 4310}, {11265, 11800, 12335}, {19290, 19825, 20360}, {27315, 27850, 28385}}, - {{3390, 3950, 4510}, {11790, 12350, 12910}, {20190, 20750, 21310}, {28590, 29150, 29710}}, + {{3390, 3950, 4510}, {11790, 12350, 12910}, {20190, 20750, 21310}, {28590, 29150, 29710}}, - {{3540, 4125, 4710}, {12315, 12900, 13485}, {21090, 21675, 22260}, {29865, 30450, 31035}}}} - }; + {{3540, 4125, 4710}, {12315, 12900, 13485}, {21090, 21675, 22260}, {29865, 30450, 31035}}}} + }; - auto xres = xt::linalg::dot(py_a, py_b); - EXPECT_TRUE(xt::allclose(xres, py_dr)); + auto xres = xt::linalg::dot(py_a, py_b); + CHECK(xt::allclose(xres, py_dr)); + } } -} // namespace xt +} diff --git a/test/test_float_norm.cpp b/test/test_float_norm.cpp index c087238..4597c67 100644 --- a/test/test_float_norm.cpp +++ b/test/test_float_norm.cpp @@ -14,25 +14,27 @@ #include "xtensor/generators/xrandom.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xblas.hpp" #include "xtensor-blas/xlinalg.hpp" namespace xt { - TEST(xblas, norm_complex_float) + TEST_SUITE("xblas") { - xt::xarray> a = {std::complex(1.0f, 2.0f), std::complex(3.0f, 4.0f)}; - auto res = linalg::norm(a); - - EXPECT_NEAR(res.real(), 5.4772f, 1e-3f); - EXPECT_NEAR(res.imag(), 0.0f, 1e-3f); + TEST_CASE("norm_complex_float") + { + xt::xarray> a = {std::complex(1.0f, 2.0f), std::complex(3.0f, 4.0f)}; + auto res = linalg::norm(a); + + CHECK(res.real() == doctest::Approx(5.4772f).epsilon(1e-3f)); + CHECK(res.imag() == doctest::Approx(0.0f).epsilon(1e-3f)); + } + + TEST_CASE("norm_float_arange") + { + xt::linalg::norm(xt::arange(15), 1); + } } - - TEST(xblas, norm_float_arange) - { - xt::linalg::norm(xt::arange(15), 1); - } - } // namespace xt diff --git a/test/test_generator/cppy_source/test_dot_extended.cppy b/test/test_generator/cppy_source/test_dot_extended.cppy index 20cbd20..22783cf 100644 --- a/test/test_generator/cppy_source/test_dot_extended.cppy +++ b/test/test_generator/cppy_source/test_dot_extended.cppy @@ -9,7 +9,7 @@ #include -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor/containers/xarray.hpp" #include "xtensor/containers/xtensor.hpp" @@ -19,53 +19,55 @@ namespace xt { using namespace xt::placeholders; - /*py - a = np.random.random((2, 3, 5)) - b = np.random.random((4, 5)) - dr = np.dot(a, b.T) - */ - TEST(xtest_extended, dot_broadcast) + TEST_SUITE("xdot_extended") { - // py_a - // py_b - // py_dr + /*py + a = np.random.random((2, 3, 5)) + b = np.random.random((4, 5)) + dr = np.dot(a, b.T) + */ + TEST_CASE("dot_broadcast") + { + // py_a + // py_b + // py_dr - xt::xtensor bas = xt::transpose(py_b); + xt::xtensor bas = xt::transpose(py_b); - auto xres = xt::linalg::dot(py_a, xt::transpose(py_b)); - auto xres2 = xt::linalg::dot(py_a, bas); - std::cout << xres << std::endl; - EXPECT_TRUE(xt::allclose(xres, py_dr)); - EXPECT_TRUE(xt::allclose(xres2, py_dr)); - } + auto xres = xt::linalg::dot(py_a, xt::transpose(py_b)); + auto xres2 = xt::linalg::dot(py_a, bas); + CHECK(xt::allclose(xres, py_dr)); + CHECK(xt::allclose(xres2, py_dr)); + } - /*py - a = np.random.random((2, 3, 5)) - b = np.random.random((5)) - dr = np.dot(a, b) - */ - TEST(xtest_extended, dot_broadcast_2) - { - // py_a - // py_b - // py_dr + /*py + a = np.random.random((2, 3, 5)) + b = np.random.random((5)) + dr = np.dot(a, b) + */ + TEST_CASE("dot_broadcast_2") + { + // py_a + // py_b + // py_dr - auto xres = xt::linalg::dot(py_a, py_b); - std::cout << xres << std::endl; - EXPECT_TRUE(xt::allclose(xres, py_dr)); - } - /*py - a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) - b = np.arange(4 * 5 * 3).reshape(4, 5, 3) - dr = np.dot(a, b) - */ - TEST(xtest_extended, dot_broadcast_3) - { - // py_a - // py_b - // py_dr + auto xres = xt::linalg::dot(py_a, py_b); + CHECK(xt::allclose(xres, py_dr)); + } + + /*py + a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5) + b = np.arange(4 * 5 * 3).reshape(4, 5, 3) + dr = np.dot(a, b) + */ + TEST_CASE("dot_broadcast_3") + { + // py_a + // py_b + // py_dr - auto xres = xt::linalg::dot(py_a, py_b); - EXPECT_TRUE(xt::allclose(xres, py_dr)); + auto xres = xt::linalg::dot(py_a, py_b); + CHECK(xt::allclose(xres, py_dr)); + } } } diff --git a/test/test_generator/cppy_source/test_lstsq.cppy b/test/test_generator/cppy_source/test_lstsq.cppy index ce504eb..e924c8a 100644 --- a/test/test_generator/cppy_source/test_lstsq.cppy +++ b/test/test_generator/cppy_source/test_lstsq.cppy @@ -9,7 +9,7 @@ #include -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor/containers/xarray.hpp" #include "xtensor/containers/xfixed.hpp" #include "xtensor/core/xnoalias.hpp" @@ -23,170 +23,169 @@ namespace xt { using namespace xt::placeholders; - /*py - a = np.random.random((6, 3)) - b = np.ones((6)) - */ - TEST(xtest_extended, lstsq1) + TEST_SUITE("xlstsq_extended") { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); + /*py + a = np.random.random((6, 3)) + b = np.ones((6)) + */ + TEST_CASE("lstsq1") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.random.random((3, 3)) + b = np.ones((3)) + */ + TEST_CASE("lstsq20") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.random.random((3, 3)) + b = np.ones((3, 3)) + */ + TEST_CASE("lstsq21") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.random.random((2, 5)) + b = np.ones((2)) + */ + TEST_CASE("lstsq3") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.random.random((2, 5)) + b = np.ones((2, 10)) + */ + TEST_CASE("lstsq4") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.random.random((10, 5)) + b = np.ones((10, 20)) + */ + TEST_CASE("lstsq5") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.array([[0., 1.]]) + b = np.array([1.]) + */ + TEST_CASE("lstsq6") + { + // py_a + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } + + /*py + a = np.array([[1.], [1.]]) + b = np.array([1., 1.]) + */ + TEST_CASE("lstsq7") + { + // cannot use "// py_a" due to ambiguous initializer list conversion below + // xarray py_a = {{1.}, + // {1.}}; + xarray py_a = xt::ones({2, 1}); + // py_b + // py_res0 = np.linalg.lstsq(a, b)[0] + // py_res1 = np.linalg.lstsq(a, b)[1] + // py_res2 = np.linalg.lstsq(a, b)[2] + // py_res3 = np.linalg.lstsq(a, b)[3] + + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } } - - /*py - a = np.random.random((3, 3)) - b = np.ones((3)) - */ - TEST(xtest_extended, lstsq20) - { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.random.random((3, 3)) - b = np.ones((3, 3)) - */ - TEST(xtest_extended, lstsq21) - { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.random.random((2, 5)) - b = np.ones((2)) - */ - TEST(xtest_extended, lstsq3) - { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.random.random((2, 5)) - b = np.ones((2, 10)) - */ - TEST(xtest_extended, lstsq4) - { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.random.random((10, 5)) - b = np.ones((10, 20)) - */ - TEST(xtest_extended, lstsq5) - { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.array([[0., 1.]]) - b = np.array([1.]) - */ - TEST(xtest_extended, lstsq6) - { - // py_a - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.array([[1.], [1.]]) - b = np.array([1., 1.]) - */ - TEST(xtest_extended, lstsq7) - { - // cannot use "// py_a" due to ambiguous initializer list conversion below - // xarray py_a = {{1.}, - // {1.}}; - xarray py_a = xt::ones({2, 1}); - // py_b - // py_res0 = np.linalg.lstsq(a, b)[0] - // py_res1 = np.linalg.lstsq(a, b)[1] - // py_res2 = np.linalg.lstsq(a, b)[2] - // py_res3 = np.linalg.lstsq(a, b)[3] - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - } diff --git a/test/test_generator/cppy_source/test_qr.cppy b/test/test_generator/cppy_source/test_qr.cppy index b81573e..bb218f6 100644 --- a/test/test_generator/cppy_source/test_qr.cppy +++ b/test/test_generator/cppy_source/test_qr.cppy @@ -9,7 +9,7 @@ #include -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor/containers/xarray.hpp" #include "xtensor/containers/xfixed.hpp" #include "xtensor/core/xnoalias.hpp" @@ -23,77 +23,80 @@ namespace xt { using namespace xt::placeholders; - /*py - a = np.random.random((6, 3)) - res_q1 = np.linalg.qr(a, 'raw') - res_q2 = np.linalg.qr(a, 'complete') - res_q3 = np.linalg.qr(a, 'reduced') - res_q4 = np.linalg.qr(a, 'r') - */ - TEST(xtest_extended, qr1) + TEST_SUITE("xqr_extended") { - // py_a - // py_resq1_h = res_q1[0] - // py_resq1_tau = res_q1[1] - - auto res1 = linalg::qr(py_a, linalg::qrmode::raw); - EXPECT_TRUE(allclose(std::get<0>(res1), py_resq1_h)); - EXPECT_TRUE(allclose(std::get<1>(res1), py_resq1_tau)); - - // py_resq2_q_cmpl = res_q2[0] - // py_resq2_r_cmpl = res_q2[1] - - auto res2 = linalg::qr(py_a, linalg::qrmode::complete); - EXPECT_TRUE(allclose(std::get<0>(res2), py_resq2_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res2), py_resq2_r_cmpl)); - - // py_resq3_q_cmpl = res_q3[0] - // py_resq3_r_cmpl = res_q3[1] - - auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); - EXPECT_TRUE(allclose(std::get<0>(res3), py_resq3_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res3), py_resq3_r_cmpl)); - - // py_resq4_r_r = res_q4 - - auto res4 = linalg::qr(py_a, linalg::qrmode::r); - EXPECT_TRUE(allclose(std::get<1>(res4), py_resq4_r_r)); - } - - /*py - a = np.random.random((5, 10)) - res_q1 = np.linalg.qr(a, 'raw') - res_q2 = np.linalg.qr(a, 'complete') - res_q3 = np.linalg.qr(a, 'reduced') - res_q4 = np.linalg.qr(a, 'r') - */ - TEST(xtest_extended, qr2) - { - // py_a - // py_resq1_h = res_q1[0] - // py_resq1_tau = res_q1[1] - - auto res1 = linalg::qr(py_a, linalg::qrmode::raw); - EXPECT_TRUE(allclose(std::get<0>(res1), py_resq1_h)); - EXPECT_TRUE(allclose(std::get<1>(res1), py_resq1_tau)); - // py_resq2_q_cmpl = res_q2[0] - // py_resq2_r_cmpl = res_q2[1] - - auto res2 = linalg::qr(py_a, linalg::qrmode::complete); - EXPECT_TRUE(allclose(std::get<0>(res2), py_resq2_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res2), py_resq2_r_cmpl)); - - // py_resq3_q_cmpl = res_q3[0] - // py_resq3_r_cmpl = res_q3[1] - - auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); - EXPECT_TRUE(allclose(std::get<0>(res3), py_resq3_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res3), py_resq3_r_cmpl)); - - // py_resq4_r_r = res_q4 - - auto res4 = linalg::qr(py_a, linalg::qrmode::r); - EXPECT_TRUE(allclose(std::get<1>(res4), py_resq4_r_r)); - + /*py + a = np.random.random((6, 3)) + res_q1 = np.linalg.qr(a, 'raw') + res_q2 = np.linalg.qr(a, 'complete') + res_q3 = np.linalg.qr(a, 'reduced') + res_q4 = np.linalg.qr(a, 'r') + */ + TEST_CASE("qr1") + { + // py_a + // py_resq1_h = res_q1[0] + // py_resq1_tau = res_q1[1] + + auto res1 = linalg::qr(py_a, linalg::qrmode::raw); + CHECK(allclose(std::get<0>(res1), py_resq1_h)); + CHECK(allclose(std::get<1>(res1), py_resq1_tau)); + + // py_resq2_q_cmpl = res_q2[0] + // py_resq2_r_cmpl = res_q2[1] + + auto res2 = linalg::qr(py_a, linalg::qrmode::complete); + CHECK(allclose(std::get<0>(res2), py_resq2_q_cmpl)); + CHECK(allclose(std::get<1>(res2), py_resq2_r_cmpl)); + + // py_resq3_q_cmpl = res_q3[0] + // py_resq3_r_cmpl = res_q3[1] + + auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); + CHECK(allclose(std::get<0>(res3), py_resq3_q_cmpl)); + CHECK(allclose(std::get<1>(res3), py_resq3_r_cmpl)); + + // py_resq4_r_r = res_q4 + + auto res4 = linalg::qr(py_a, linalg::qrmode::r); + CHECK(allclose(std::get<1>(res4), py_resq4_r_r)); + } + + /*py + a = np.random.random((5, 10)) + res_q1 = np.linalg.qr(a, 'raw') + res_q2 = np.linalg.qr(a, 'complete') + res_q3 = np.linalg.qr(a, 'reduced') + res_q4 = np.linalg.qr(a, 'r') + */ + TEST_CASE("qr2") + { + // py_a + // py_resq1_h = res_q1[0] + // py_resq1_tau = res_q1[1] + + auto res1 = linalg::qr(py_a, linalg::qrmode::raw); + CHECK(allclose(std::get<0>(res1), py_resq1_h)); + CHECK(allclose(std::get<1>(res1), py_resq1_tau)); + // py_resq2_q_cmpl = res_q2[0] + // py_resq2_r_cmpl = res_q2[1] + + auto res2 = linalg::qr(py_a, linalg::qrmode::complete); + CHECK(allclose(std::get<0>(res2), py_resq2_q_cmpl)); + CHECK(allclose(std::get<1>(res2), py_resq2_r_cmpl)); + + // py_resq3_q_cmpl = res_q3[0] + // py_resq3_r_cmpl = res_q3[1] + + auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); + CHECK(allclose(std::get<0>(res3), py_resq3_q_cmpl)); + CHECK(allclose(std::get<1>(res3), py_resq3_r_cmpl)); + + // py_resq4_r_r = res_q4 + + auto res4 = linalg::qr(py_a, linalg::qrmode::r); + CHECK(allclose(std::get<1>(res4), py_resq4_r_r)); + + } } } diff --git a/test/test_lapack.cpp b/test/test_lapack.cpp index dcf507b..9b1bd32 100644 --- a/test/test_lapack.cpp +++ b/test/test_lapack.cpp @@ -13,7 +13,7 @@ #include "xtensor/misc/xcomplex.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xblas.hpp" #include "xtensor-blas/xlapack.hpp" #include "xtensor-blas/xlinalg.hpp" @@ -22,199 +22,204 @@ using namespace std::complex_literals; namespace xt { - TEST(xlapack, eigenvalues) + TEST_SUITE("xlapack") { - xarray eig_arg_0 = { - {0.89342434, 0.96630682, 0.83113658, 0.9014204, 0.17622395}, - {0.01114647, 0.93096724, 0.35509599, 0.35329223, 0.65759337}, - {0.27868701, 0.376794, 0.63310696, 0.90892131, 0.35454718}, - {0.02962539, 0.20561053, 0.2004051, 0.83641883, 0.08335324}, - {0.76958296, 0.23132089, 0.33539779, 0.70616527, 0.40256713} - }; - auto eig_res = xt::linalg::eig(eig_arg_0); - - xtensor, 1> eig_expected_0 = { - 2.24745601 + 0.i, - 0.24898158 + 0.51158566i, - 0.24898158 - 0.51158566i, - 0.66252212 + 0.i, - 0.28854321 + 0.i - }; - - xtensor, 2> eig_expected_1 = { - {-0.67843725 + 0.i, - -0.00104977 + 0.50731553i, - -0.00104977 - 0.50731553i, - -0.48456457 + 0.i, - -0.11153304 + 0.i}, - {-0.38393722 + 0.i, - -0.42892828 - 0.30675499i, - -0.42892828 + 0.30675499i, - -0.60497432 + 0.i, - -0.55233486 + 0.i}, - {-0.39453548 + 0.i, 0.10153693 - 0.12657944i, 0.10153693 + 0.12657944i, 0.35111489 + 0.i, 0.80267297 + 0.i - }, - {-0.15349367 + 0.i, - -0.04903747 + 0.08226059i, - -0.04903747 - 0.08226059i, - 0.48726345 + 0.i, - -0.10533951 + 0.i}, - {-0.46162383 + 0.i, 0.65501769 + 0.i, 0.65501769 - 0.i, -0.19620376 + 0.i, 0.16463982 + 0.i} - }; - xarray> eigvals = std::get<0>(eig_res); - xarray> eigvecs = std::get<1>(eig_res); - - EXPECT_TRUE(allclose(xt::imag(eigvals), xt::imag(eig_expected_0))); - EXPECT_TRUE(allclose(xt::real(eigvals), xt::real(eig_expected_0))); - EXPECT_TRUE(allclose(abs(imag(eigvecs)), abs(imag(eig_expected_1)))); - EXPECT_TRUE(allclose(abs(real(eigvecs)), abs(real(eig_expected_1)))); - } - - TEST(xlapack, generalized_eigenvalues) - { - xarray eig_arg_0 = { - {0.24, 0.39, 0.42, -0.16}, - {0.39, -0.11, 0.79, 0.63}, - {0.42, 0.79, -0.25, 0.48}, - {-0.16, 0.63, 0.48, -0.03} - }; - xarray eig_arg_1 = { - {4.16, -3.12, 0.56, -0.10}, - {-3.12, 5.03, -0.83, 1.09}, - {0.56, -0.83, 0.76, 0.34}, - {-0.10, 1.09, 0.34, 1.18} - }; - auto eig_res = xt::linalg::eigh(eig_arg_0, eig_arg_1); - xtensor eig_expected_0 = {-2.225448, -0.454756, 0.100076, 1.127039}; - xtensor eig_expected_1 = { - {0.031913, 0.327020, 0.682699, 0.425628}, - {0.265466, 0.565845, 0.056645, 0.520961}, - {0.713483, -0.371290, -0.077102, 0.714215}, - {-0.647650, -0.659561, -0.724409, -0.193227} - }; - xarray eigvals = std::get<0>(eig_res); - xarray eigvecs = std::get<1>(eig_res); - for (unsigned i = 0; i < 4; ++i) + TEST_CASE("eigenvalues") { - auto v = xt::view(eigvecs, xt::all(), i); - v /= xt::linalg::norm(v, 2); - if (v(0) < 0.0) - v = -v; + xarray eig_arg_0 = { + {0.89342434, 0.96630682, 0.83113658, 0.9014204, 0.17622395}, + {0.01114647, 0.93096724, 0.35509599, 0.35329223, 0.65759337}, + {0.27868701, 0.376794, 0.63310696, 0.90892131, 0.35454718}, + {0.02962539, 0.20561053, 0.2004051, 0.83641883, 0.08335324}, + {0.76958296, 0.23132089, 0.33539779, 0.70616527, 0.40256713} + }; + auto eig_res = xt::linalg::eig(eig_arg_0); + + xtensor, 1> eig_expected_0 = { + 2.24745601 + 0.i, + 0.24898158 + 0.51158566i, + 0.24898158 - 0.51158566i, + 0.66252212 + 0.i, + 0.28854321 + 0.i + }; + + xtensor, 2> eig_expected_1 = { + {-0.67843725 + 0.i, + -0.00104977 + 0.50731553i, + -0.00104977 - 0.50731553i, + -0.48456457 + 0.i, + -0.11153304 + 0.i}, + {-0.38393722 + 0.i, + -0.42892828 - 0.30675499i, + -0.42892828 + 0.30675499i, + -0.60497432 + 0.i, + -0.55233486 + 0.i}, + {-0.39453548 + 0.i, + 0.10153693 - 0.12657944i, + 0.10153693 + 0.12657944i, + 0.35111489 + 0.i, + 0.80267297 + 0.i}, + {-0.15349367 + 0.i, + -0.04903747 + 0.08226059i, + -0.04903747 - 0.08226059i, + 0.48726345 + 0.i, + -0.10533951 + 0.i}, + {-0.46162383 + 0.i, 0.65501769 + 0.i, 0.65501769 - 0.i, -0.19620376 + 0.i, 0.16463982 + 0.i} + }; + xarray> eigvals = std::get<0>(eig_res); + xarray> eigvecs = std::get<1>(eig_res); + + CHECK(allclose(xt::imag(eigvals), xt::imag(eig_expected_0))); + CHECK(allclose(xt::real(eigvals), xt::real(eig_expected_0))); + CHECK(allclose(abs(imag(eigvecs)), abs(imag(eig_expected_1)))); + CHECK(allclose(abs(real(eigvecs)), abs(real(eig_expected_1)))); } - EXPECT_TRUE(allclose(eigvals, eig_expected_0)); - EXPECT_TRUE(allclose(abs(eigvecs), abs(eig_expected_1))); - } + TEST_CASE("generalized_eigenvalues") + { + xarray eig_arg_0 = { + {0.24, 0.39, 0.42, -0.16}, + {0.39, -0.11, 0.79, 0.63}, + {0.42, 0.79, -0.25, 0.48}, + {-0.16, 0.63, 0.48, -0.03} + }; + xarray eig_arg_1 = { + {4.16, -3.12, 0.56, -0.10}, + {-3.12, 5.03, -0.83, 1.09}, + {0.56, -0.83, 0.76, 0.34}, + {-0.10, 1.09, 0.34, 1.18} + }; + auto eig_res = xt::linalg::eigh(eig_arg_0, eig_arg_1); + xtensor eig_expected_0 = {-2.225448, -0.454756, 0.100076, 1.127039}; + xtensor eig_expected_1 = { + {0.031913, 0.327020, 0.682699, 0.425628}, + {0.265466, 0.565845, 0.056645, 0.520961}, + {0.713483, -0.371290, -0.077102, 0.714215}, + {-0.647650, -0.659561, -0.724409, -0.193227} + }; + xarray eigvals = std::get<0>(eig_res); + xarray eigvecs = std::get<1>(eig_res); + for (unsigned i = 0; i < 4; ++i) + { + auto v = xt::view(eigvecs, xt::all(), i); + v /= xt::linalg::norm(v, 2); + if (v(0) < 0.0) + v = -v; + } + + CHECK(allclose(eigvals, eig_expected_0)); + CHECK(allclose(abs(eigvecs), abs(eig_expected_1))); + } - TEST(xlapack, inverse) - { - xarray a = {{2, 1, 1}, {-1, 1, -1}, {1, 2, 3}}; + TEST_CASE("inverse") + { + xarray a = {{2, 1, 1}, {-1, 1, -1}, {1, 2, 3}}; - xarray b = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}; + xarray b = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}; - auto t = linalg::inv(a); + auto t = linalg::inv(a); - xarray expected = { - {0.55555556, -0.11111111, -0.22222222}, - {0.22222222, 0.55555556, 0.11111111}, - {-0.33333333, -0.33333333, 0.33333333} - }; + xarray expected = { + {0.55555556, -0.11111111, -0.22222222}, + {0.22222222, 0.55555556, 0.11111111}, + {-0.33333333, -0.33333333, 0.33333333} + }; - EXPECT_TRUE(allclose(expected, t)); + CHECK(allclose(expected, t)); - auto br = linalg::inv(b); - EXPECT_EQ(b, br); - auto t_r_major = xarray::from_shape({3, 3}); - assign_data(t_r_major, t, true); - auto almost_eye = linalg::dot(t_r_major, a); - auto e = xt::eye(3); - auto d = almost_eye - e; - auto min = xt::amin(d); - EXPECT_NEAR(min(), 0.0, 1e-6); - } + auto br = linalg::inv(b); + CHECK_EQ(b, br); + auto t_r_major = xarray::from_shape({3, 3}); + assign_data(t_r_major, t, true); + auto almost_eye = linalg::dot(t_r_major, a); + auto e = xt::eye(3); + auto d = almost_eye - e; + auto min = xt::amin(d); + CHECK(min() == doctest::Approx(0.0).epsilon(1e-6)); + } - TEST(xlapack, single_element_inverse) - { - xtensor a = xt::ones({1, 1}); - auto res = linalg::inv(a); - EXPECT_EQ(res(), 1.); - } + TEST_CASE("single_element_inverse") + { + xtensor a = xt::ones({1, 1}); + auto res = linalg::inv(a); + CHECK_EQ(res(), 1.); + } - TEST(xlapack, solve) - { - xarray a = {{2, 1, 1}, {-1, 1, -1}, {1, 2, 3}}; + TEST_CASE("solve") + { + xarray a = {{2, 1, 1}, {-1, 1, -1}, {1, 2, 3}}; - xarray vec = {2, 3, -10}; - xarray expected = {3, 1, -5}; + xarray vec = {2, 3, -10}; + xarray expected = {3, 1, -5}; - auto res = linalg::solve(a, vec); - EXPECT_EQ(expected, res); + auto res = linalg::solve(a, vec); + CHECK_EQ(expected, res); - vec.reshape({3, 1}); - expected.reshape({3, 1}); - auto res2 = linalg::solve(a, vec); - EXPECT_EQ(expected, res2); + vec.reshape({3, 1}); + expected.reshape({3, 1}); + auto res2 = linalg::solve(a, vec); + CHECK_EQ(expected, res2); - xarray vec2 = {6, 2, -10}; - vec2.reshape({3, 1}); + xarray vec2 = {6, 2, -10}; + vec2.reshape({3, 1}); - auto res3 = linalg::solve(a, concatenate(xtuple(vec, vec2 * 3), 1)); - xarray expected3 = {{3, 16}, {1, 4}, {-5, -18}}; - EXPECT_EQ(expected3, res3); - } + auto res3 = linalg::solve(a, concatenate(xtuple(vec, vec2 * 3), 1)); + xarray expected3 = {{3, 16}, {1, 4}, {-5, -18}}; + CHECK_EQ(expected3, res3); + } - TEST(xlapack, solveCholesky) - { - xarray A = { - {1., 0., 0., 0., 0.}, - {0.44615865, 0.89495389, 0., 0., 0.}, - {0.39541532, 0.24253783, 0.88590187, 0., 0.}, - {-0.36681098, -0.26249522, 0.0338034, 0.89185386, 0.}, - {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433} - }; - - xarray b = {1, 1, 1, -1, -1}; - auto x = linalg::solve_cholesky(A, b); - - const xarray x_expected = { - 0.13757507429403265, - 0.26609253571318064, - 1.03715526610177222, - -1.3449222878385465, - -1.81183493755905478 - }; - - for (std::size_t i = 0; i < x_expected.shape()[0]; ++i) + TEST_CASE("solveCholesky") { - EXPECT_NEAR(x_expected[i], x[i], 5e-16); + xarray A = { + {1., 0., 0., 0., 0.}, + {0.44615865, 0.89495389, 0., 0., 0.}, + {0.39541532, 0.24253783, 0.88590187, 0., 0.}, + {-0.36681098, -0.26249522, 0.0338034, 0.89185386, 0.}, + {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433} + }; + + xarray b = {1, 1, 1, -1, -1}; + auto x = linalg::solve_cholesky(A, b); + + const xarray x_expected = { + 0.13757507429403265, + 0.26609253571318064, + 1.03715526610177222, + -1.3449222878385465, + -1.81183493755905478 + }; + + for (std::size_t i = 0; i < x_expected.shape()[0]; ++i) + { + CHECK(x_expected[i] == doctest::Approx(x[i]).epsilon(5e-16)); + } } - } - TEST(xlapack, solveTriangular) - { - const xt::xtensor A = { - {1., 0., 0., 0., 0.}, - {0.44615865, 0.89495389, 0., 0., 0.}, - {0.39541532, 0.24253783, 0.88590187, 0., 0.}, - {-0.36681098, -0.26249522, 0.0338034, 0.89185386, 0.}, - {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433} - }; - - const xt::xtensor b = {0.38867999, 0.46467046, 0.39042938, -0.2736973, 0.20813322}; - auto x = linalg::solve_triangular(A, b); - - const xarray x_expected = { - 0.38867998999999998, - 0.32544416381003327, - 0.17813128230545805, - -0.05799057434472885, - 0.08606304705465571 - }; - - for (std::size_t i = 0; i < x_expected.shape()[0]; ++i) + TEST_CASE("solveTriangular") { - EXPECT_DOUBLE_EQ(x_expected[i], x[i]); + const xt::xtensor A = { + {1., 0., 0., 0., 0.}, + {0.44615865, 0.89495389, 0., 0., 0.}, + {0.39541532, 0.24253783, 0.88590187, 0., 0.}, + {-0.36681098, -0.26249522, 0.0338034, 0.89185386, 0.}, + {0.0881614, 0.12356345, 0.19887529, -0.35996807, 0.89879433} + }; + + const xt::xtensor b = {0.38867999, 0.46467046, 0.39042938, -0.2736973, 0.20813322}; + auto x = linalg::solve_triangular(A, b); + + const xarray x_expected = { + 0.38867998999999998, + 0.32544416381003327, + 0.17813128230545805, + -0.05799057434472885, + 0.08606304705465571 + }; + + for (std::size_t i = 0; i < x_expected.shape()[0]; ++i) + { + CHECK(x_expected[i] == doctest::Approx(x[i]).epsilon(5e-16)); + } } } - } // namespace xt diff --git a/test/test_linalg.cpp b/test/test_linalg.cpp index bb1327a..df00fca 100644 --- a/test/test_linalg.cpp +++ b/test/test_linalg.cpp @@ -14,7 +14,7 @@ #include "xtensor/misc/xcomplex.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xblas.hpp" #include "xtensor-blas/xlapack.hpp" #include "xtensor-blas/xlinalg.hpp" @@ -23,605 +23,615 @@ using namespace std::complex_literals; namespace xt { - TEST(xlinalg, matrixpower) - { - xarray t1arg_0 = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}; - - auto t1res = xt::linalg::matrix_power(t1arg_0, 2); - xarray t1expected = {{15, 18, 21}, {42, 54, 66}, {69, 90, 111}}; - EXPECT_TRUE(allclose(t1res, t1expected)); - - auto t2res = xt::linalg::matrix_power(t1arg_0, 5); - xarray t2expected = {{32400, 41796, 51192}, {99468, 128304, 157140}, {166536, 214812, 263088}}; - EXPECT_TRUE(allclose(t2res, t2expected)); - - auto t3res = xt::linalg::matrix_power(t1arg_0, 41); - xarray t3expected = { - {1.06199622e+45, 1.36986674e+45, 1.67773727e+45}, - {3.26000325e+45, 4.20507151e+45, 5.15013977e+45}, - {5.45801029e+45, 7.04027628e+45, 8.62254226e+45} - }; - EXPECT_TRUE(allclose(t3res, t3expected)); - - xarray t4arg_0 = {{-2., 1., 3.}, {3., 2., 1.}, {1., 2., 5.}}; - - auto t4res = xt::linalg::matrix_power(t4arg_0, -2); - xarray t4expected = { - {0.09259259, -0.09259259, 0.01851852}, - {0.35185185, 0.64814815, -0.46296296}, - {-0.2037037, -0.2962963, 0.25925926} - }; - EXPECT_TRUE(allclose(t4res, t4expected)); - - auto t5res = xt::linalg::matrix_power(t4arg_0, -13); - xarray t5expected = { - {-0.02119919, -0.02993041, 0.02400524}, - {0.15202629, 0.21469317, -0.17217602}, - {-0.0726041, -0.10253451, 0.08222825} - }; - EXPECT_TRUE(allclose(t5res, t5expected)); - } +#define EXPECT_NEAR(a, b, c) CHECK(a == doctest::Approx(b).epsilon(c)) - TEST(xlinalg, det) + TEST_SUITE("xlinalg") { - xarray a = {{1, 2}, {3, 4}}; - double da = linalg::det(a); - EXPECT_EQ(da, -2.0); - xarray b = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}; - double db = linalg::det(b); - EXPECT_EQ(db, 0.0); - xarray c = {{12, 1, 2}, {3, 4, 5}, {6, 7, 8}}; - double dc = linalg::det(c); - EXPECT_NEAR(dc, -36, 1e-06); - - xarray> arg_0 = { - {0.95368636 + 0.32324664i, 0.49936872 + 0.22164004i, 0.30452434 + 0.78922905i}, - {0.84118920 + 0.59652768i, 0.42052057 + 0.97211559i, 0.19916742 + 0.83068058i}, - {0.67065616 + 0.56830636i, 0.00268706 + 0.29410473i, 0.69147455 + 0.7052149i} - }; - auto res = linalg::det(arg_0); - - auto res_i = std::imag(res); - auto res_r = std::real(res); - EXPECT_NEAR(0.4201495908415372, res_i, 1e-06); - EXPECT_NEAR(-0.07633013993862534, res_r, 1e-06); - } + TEST_CASE("matrixpower") + { + xarray t1arg_0 = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}; + + auto t1res = xt::linalg::matrix_power(t1arg_0, 2); + xarray t1expected = {{15, 18, 21}, {42, 54, 66}, {69, 90, 111}}; + CHECK(allclose(t1res, t1expected)); + + auto t2res = xt::linalg::matrix_power(t1arg_0, 5); + xarray t2expected = {{32400, 41796, 51192}, {99468, 128304, 157140}, {166536, 214812, 263088}}; + CHECK(allclose(t2res, t2expected)); + + auto t3res = xt::linalg::matrix_power(t1arg_0, 41); + xarray t3expected = { + {1.06199622e+45, 1.36986674e+45, 1.67773727e+45}, + {3.26000325e+45, 4.20507151e+45, 5.15013977e+45}, + {5.45801029e+45, 7.04027628e+45, 8.62254226e+45} + }; + CHECK(allclose(t3res, t3expected)); + + xarray t4arg_0 = {{-2., 1., 3.}, {3., 2., 1.}, {1., 2., 5.}}; + + auto t4res = xt::linalg::matrix_power(t4arg_0, -2); + xarray t4expected = { + {0.09259259, -0.09259259, 0.01851852}, + {0.35185185, 0.64814815, -0.46296296}, + {-0.2037037, -0.2962963, 0.25925926} + }; + CHECK(allclose(t4res, t4expected)); + + auto t5res = xt::linalg::matrix_power(t4arg_0, -13); + xarray t5expected = { + {-0.02119919, -0.02993041, 0.02400524}, + {0.15202629, 0.21469317, -0.17217602}, + {-0.0726041, -0.10253451, 0.08222825} + }; + CHECK(allclose(t5res, t5expected)); + } + + TEST_CASE("det") + { + xarray a = {{1, 2}, {3, 4}}; + double da = linalg::det(a); + CHECK_EQ(da, -2.0); + xarray b = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}; + double db = linalg::det(b); + CHECK_EQ(db, 0.0); + xarray c = {{12, 1, 2}, {3, 4, 5}, {6, 7, 8}}; + double dc = linalg::det(c); + EXPECT_NEAR(dc, -36, 1e-06); + + xarray> arg_0 = { + {0.95368636 + 0.32324664i, 0.49936872 + 0.22164004i, 0.30452434 + 0.78922905i}, + {0.84118920 + 0.59652768i, 0.42052057 + 0.97211559i, 0.19916742 + 0.83068058i}, + {0.67065616 + 0.56830636i, 0.00268706 + 0.29410473i, 0.69147455 + 0.7052149i} + }; + auto res = linalg::det(arg_0); + + auto res_i = std::imag(res); + auto res_r = std::real(res); + EXPECT_NEAR(0.4201495908415372, res_i, 1e-06); + EXPECT_NEAR(-0.07633013993862534, res_r, 1e-06); + } + + TEST_CASE("slogdet") + { + xarray> arg_0 = { + {0.13373658 + 0.43025551i, 0.42593478 + 0.17539337i, 0.18840853 + 0.24669458i}, + {0.82800224 + 0.11797823i, 0.40310379 + 0.14037109i, 0.88204561 + 0.96870283i}, + {0.35427657 + 0.1233739i, 0.22740960 + 0.94019582i, 0.05410180 + 0.86462543i} + }; + auto resc = linalg::slogdet(arg_0); + auto sc = std::get<0>(resc); + auto sl = std::real(std::get<1>(resc)); + auto scr = std::real(sc); + auto sci = std::imag(sc); + + EXPECT_NEAR(-0.8818794751329891, sl, 1e-06); + EXPECT_NEAR(0.8473375077176295, scr, 1e-06); + EXPECT_NEAR(0.5310547504870624, sci, 1e-06); + + xarray arg_b = { + {0.20009016, 0.33997118, 0.74433611}, + {0.52721448, 0.2449798, 0.49085606}, + {0.49757477, 0.97304175, 0.05011255} + }; + auto res = linalg::slogdet(arg_b); + double expected_0 = 1.0; + double expected_1 = -1.3017524147193602; + auto sres = std::get<0>(res); + auto lres = std::get<1>(res); + CHECK_EQ(expected_0, sres); + EXPECT_NEAR(expected_1, lres, 1e-06); + } + + TEST_CASE("svd") + { + xarray arg_0 = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}; + + auto res = linalg::svd(arg_0); + + xarray expected_0 = { + {-0.13511895, 0.90281571, 0.40824829}, + {-0.49633514, 0.29493179, -0.81649658}, + {-0.85755134, -0.31295213, 0.40824829} + }; + xarray expected_1 = {1.42267074e+01, 1.26522599e+00, 5.89938022e-16}; + xarray expected_2 = { + {-0.4663281, -0.57099079, -0.67565348}, + {-0.78477477, -0.08545673, 0.61386131}, + {-0.40824829, 0.81649658, -0.40824829} + }; + + CHECK(allclose(std::get<0>(res), expected_0)); + CHECK(allclose(std::get<1>(res), expected_1)); + CHECK(allclose(std::get<2>(res), expected_2)); + } + + TEST_CASE("svd_horizontal_vertical") + { + xarray a = xt::ones({3, 1}); + xarray b = xt::ones({1, 3}); + xarray u, s, vt; + + std::tie(u, s, vt) = linalg::svd(a, false); + CHECK(allclose(a, xt::linalg::dot(u * s, vt))); + + std::tie(u, s, vt) = linalg::svd(b, false); + CHECK(allclose(b, xt::linalg::dot(u * s, vt))); + } + + TEST_CASE("matrix_rank") + { + xarray eall = eye(4); + int a = linalg::matrix_rank(eall); + CHECK_EQ(4, a); + + xarray b = eye(4); + b(1, 1) = 0; + int rb = linalg::matrix_rank(b); + CHECK_EQ(3, rb); + xarray ones_arr = ones({4, 4}); + int ro = linalg::matrix_rank(ones_arr); + CHECK_EQ(1, ro); + xarray zarr = zeros({4, 4}); + int rz = linalg::matrix_rank(zarr); + CHECK_EQ(0, rz); + } + + TEST_CASE("eigh") + { + xarray arg_0 = {{-761., -208., -582.}, {-208., -623., -1605.5}, {-582., -1605.5, -476.}}; + auto res = xt::linalg::eigh(arg_0); + xarray expected_0 = {-2351.3290686, -609.79206435, 1101.12113295}; + xarray expected_1 = { + {-0.33220683, -0.93041946, -0.15478453}, + {-0.66309119, 0.34708777, -0.66320446}, + {-0.67078216, 0.11768479, 0.73225787} + }; + auto vals = std::get<0>(res); + auto vecs = std::get<1>(res); + CHECK(allclose(expected_0, vals)); + CHECK(allclose(expected_1, vecs)); + + auto vals_2 = xt::linalg::eigvalsh(arg_0); + CHECK(allclose(expected_0, vals_2)); + + xarray> complarg_0 = {{1. + 0.i, -0. - 2.i}, {0. + 2.i, 5. + 0.i}}; + auto complres = xt::linalg::eigh(complarg_0); + + xarray complexpected_0 = {0.17157288, 5.82842712}; + auto cmvals = std::get<0>(complres); + auto cmvecs = std::get<1>(complres); + CHECK(allclose(complexpected_0, cmvals)); + xarray, layout_type::column_major> complexpected_1 = { + {-0.92387953 + 0.i, -0.38268343 + 0.i}, + {0.00000000 + 0.38268343i, 0.00000000 - 0.92387953i} + }; + CHECK(allclose(imag(complexpected_1), imag(cmvecs))); + CHECK(allclose(real(complexpected_1), real(cmvecs))); + + auto cmvals2 = xt::linalg::eigvalsh(complarg_0); + CHECK(allclose(complexpected_0, cmvals2)); + } + + TEST_CASE("pinv") + { + xarray arg_0 = { + {1.47351391, 0.94686323, 0.92236842, -1.44141916, -1.53123963, -0.36949144}, + {-0.76686921, -0.01087083, -1.11100036, -0.59745592, -0.99849726, 0.45296729}, + {-0.35274989, -1.27760231, 1.50092545, -2.7243503, -0.79326768, -1.00826405}, + {0.05763039, 1.04069983, -0.502178, -1.01776144, 0.6496664, -0.2374513}, + {-1.45517735, 0.42523508, 0.41400096, 0.87164292, 1.87754145, 0.16358461}, + {1.07487297, -0.26417364, 1.82998799, 0.97985789, -0.74820612, -0.75097366}, + {0.91375249, 1.14211989, -0.23055478, -0.48264987, -0.4591723, 0.83185472}, + {0.05318152, -0.30014836, 1.68456715, 0.07388112, 0.0607432, -0.51529535}, + {-1.36227295, -0.12015569, -0.45599178, -1.07135129, -0.27405687, 0.50177945} + }; + auto res = xt::linalg::pinv(arg_0); + xarray expected = { + {-0.12524671, -0.41299325, 0.09108576, -0.07346514, -0.29603324, -0.1702256, 0.28503799, -0.08979346, -0.29415286 + }, + {0.38896886, 0.28355746, -0.26406943, 0.36828839, 0.29271933, 0.19368219, -0.11433648, 0.05102866, 0.11050527 + }, + {0.05803293, -0.11353613, 0.08736754, -0.20744326, 0.21647828, 0.11201996, 0.26187311, 0.24550066, 0.13766844 + }, + {0.01037663, 0.15710257, -0.24647364, -0.06611398, 0.0482986, 0.21835662, -0.18763349, 0.01747553, -0.02066564 + }, + {-0.23681835, -0.44982904, 0.14998634, 0.07608489, 0.04751429, -0.27436861, 0.18626331, -0.01520331, -0.18144752 + }, + {-0.25580995, -0.33606213, 0.10424938, -0.64100544, 0.01084618, -0.28963573, 0.89629794, 0.12987378, 0.22451995 + } + }; + CHECK(allclose(expected, res)); + + xarray> cmpl_arg_0 = { + {-0.32865615 + 1.56868725i, 0.28804396 + 0.52266479i}, + {-1.29703842 + 0.34647524i, -2.14982936 + 0.31425111i}, + {-0.69224750 - 1.36725801i, 2.22948403 + 1.4612309i} + }; + auto cmpl_res = xt::linalg::pinv(cmpl_arg_0); + xarray> cmpl_expected = { + {-0.06272312 - 0.24840107i, -0.20530381 - 0.00548715i, -0.14179276 + 0.16337684i}, + {0.05975312 - 0.0502577i, -0.17431091 - 0.05525696i, 0.16047967 - 0.14140846i} + }; + CHECK(allclose(real(cmpl_expected), real(cmpl_res))); + CHECK(allclose(imag(cmpl_expected), imag(cmpl_res))); + } + + TEST_CASE("pinv_small") + { + xt::xtensor d1{{1.f, 2.f}}; + auto r1 = xt::linalg::pinv(d1); + xt::xtensor e1 = {{0.2f}, {0.4f}}; + CHECK(allclose(r1, e1)); + + xt::xtensor d2{{1.f}}; + auto r2 = xt::linalg::pinv(d2); + CHECK_EQ(r2(0), 1.f); + } + + TEST_CASE("mat_norm") + { + xarray arg_0 = { + {0.06817001, 0.50274712, -0.36802027, -0.93123204}, + {-0.5990272, -0.67439921, -0.09397038, -1.55915724}, + {2.22694395, 0.59099048, -0.43162172, 0.19410077}, + {0.41859591, 1.68555153, 1.82660739, 1.24427635} + }; + auto res1 = xt::linalg::norm(arg_0, 1); + auto res2 = xt::linalg::norm(arg_0, linalg::normorder::frob); + auto res3 = xt::linalg::norm(arg_0, linalg::normorder::inf); + auto res4 = xt::linalg::norm(arg_0, linalg::normorder::neg_inf); + auto res5 = xt::linalg::norm(arg_0, linalg::normorder::nuc); + auto res6 = xt::linalg::norm(arg_0, 2); + double exp1 = 3.92876639061; + double exp2 = 4.23639347394; + double exp3 = 5.17503118283; + double exp4 = 1.87016943835; + double exp5 = 7.42677006218; + double exp6 = 3.29152325862; + + EXPECT_NEAR(exp1, res1, 1e-06); + EXPECT_NEAR(exp2, res2, 1e-06); + EXPECT_NEAR(exp3, res3, 1e-06); + EXPECT_NEAR(exp4, res4, 1e-06); + EXPECT_NEAR(exp5, res5, 1e-06); + EXPECT_NEAR(exp6, res6, 1e-06); + + xarray> cmplarg_0 = { + {0.40101756 + 0.71233018i, 0.62731701 + 0.42786349i, 0.32415089 + 0.2977805i}, + {0.24475928 + 0.49208478i, 0.69475518 + 0.74029639i, 0.59390240 + 0.35772892i}, + {0.63179202 + 0.41720995i, 0.44025718 + 0.65472131i, 0.08372648 + 0.37380143i} + }; + auto cmplres1 = xt::linalg::norm(cmplarg_0, 1); + auto cmplres2 = xt::linalg::norm(cmplarg_0, linalg::normorder::frob); + auto cmplres3 = xt::linalg::norm(cmplarg_0, linalg::normorder::inf); + auto cmplres4 = xt::linalg::norm(cmplarg_0, linalg::normorder::neg_inf); + auto cmplres5 = xt::linalg::norm(cmplarg_0, linalg::normorder::nuc); + auto cmplres6 = xt::linalg::norm(cmplarg_0, 2); + + double cmplexp1 = 2.56356133004; + double cmplexp2 = 2.14347558031; + double cmplexp3 = 2.25815855456; + double cmplexp4 = 1.92915797164; + double cmplexp5 = 2.77947580342; + double cmplexp6 = 2.0683368289; + + EXPECT_NEAR(cmplexp1, cmplres1, 1e-06); + EXPECT_NEAR(cmplexp2, cmplres2, 1e-06); + EXPECT_NEAR(cmplexp3, cmplres3, 1e-06); + EXPECT_NEAR(cmplexp4, cmplres4, 1e-06); + EXPECT_NEAR(cmplexp5, cmplres5, 1e-06); + EXPECT_NEAR(cmplexp6, cmplres6, 1e-06); + } + + TEST_CASE("vec_norm") + { + xarray arg_0 = {0.23451288, 0.98799529, 0.76599595, 0.77700444, 0.02798196}; + + EXPECT_NEAR(2.79349050582, xt::linalg::norm(arg_0, 1), 1e-6); + EXPECT_NEAR(1.49077149771, xt::linalg::norm(arg_0, 2), 1e-6); + EXPECT_NEAR(1.23766843269, xt::linalg::norm(arg_0, 3), 1e-6); + EXPECT_NEAR(1.13587319901, xt::linalg::norm(arg_0, 4), 1e-6); + EXPECT_NEAR(5.0, xt::linalg::norm(arg_0, 0), 1e-6); + EXPECT_NEAR(0.0229325662443, xt::linalg::norm(arg_0, -1), 1e-6); + EXPECT_NEAR(0.0277379546324, xt::linalg::norm(arg_0, -2), 1e-6); + EXPECT_NEAR(0.987995286517, xt::linalg::norm(arg_0, linalg::normorder::inf), 1e-6); + EXPECT_NEAR(0.0279819550429, xt::linalg::norm(arg_0, linalg::normorder::neg_inf), 1e-6); + + xarray> arg_1 = { + 0.23451288 + 0.77700444i, + 0.98799529 + 0.02798196i, + 0.76599595 + 0.17390652i + }; + EXPECT_NEAR(2.58550383197, xt::linalg::norm(arg_1, 1), 1e-06); + EXPECT_NEAR(1.50088078633, xt::linalg::norm(arg_1, 2), 1e-06); + EXPECT_NEAR(1.25673399279, xt::linalg::norm(arg_1, 3), 1e-06); + EXPECT_NEAR(1.15326879931, xt::linalg::norm(arg_1, 4), 1e-06); + EXPECT_NEAR(3.0, xt::linalg::norm(arg_1, 0), 1e-06); + EXPECT_NEAR(0.284338433895, xt::linalg::norm(arg_1, -1), 1e-06); + EXPECT_NEAR(0.490145522524, xt::linalg::norm(arg_1, -2), 1e-06); + EXPECT_NEAR(0.98839145888, xt::linalg::norm(arg_1, linalg::normorder::inf), 1e-06); + EXPECT_NEAR(0.785489192861, xt::linalg::norm(arg_1, linalg::normorder::neg_inf), 1e-06); + } + + TEST_CASE("vdot") + { + xarray arg_0 = {0.23451288, 0.98799529, 0.76599595, 0.77700444, 0.02798196}; + xarray arg_1 = {0.17390652, 0.15408224, 0.07708648, 0.8898657, 0.7503787}; + auto res = xt::linalg::vdot(arg_0, arg_1); + EXPECT_NEAR(0.964490439715, res, 1e-06); + + xarray> carg_0 = { + 0.23451288 + 0.17390652i, + 0.98799529 + 0.15408224i, + 0.76599595 + 0.07708648i, + 0.77700444 + 0.8898657i, + 0.02798196 + 0.7503787i + }; + xarray> carg_1 = { + 0.17390652 + 0.23451288i, + 0.15408224 + 0.98799529i, + 0.07708648 + 0.76599595i, + 0.88986570 + 0.77700444i, + 0.75037870 + 0.02798196i + }; + auto res_c = xt::linalg::vdot(carg_0, carg_1); + + EXPECT_NEAR(1.9289808794290355, std::real(res_c), 1e-06); + EXPECT_NEAR(0.8075433553117102, std::imag(res_c), 1e-06); + } + + TEST_CASE("kron") + { + xarray arg_0 = {{2, 1, 8}, {3, 5, 0}, {2, 6, 2}, {4, 4, 6}}; + + xarray arg_1 = {{3, 0, 6, 4, 7}, {6, 7, 1, 5, 7}}; + + auto res = xt::linalg::kron(arg_0, arg_1); + + xarray expected = { + {6, 0, 12, 8, 14, 3, 0, 6, 4, 7, 24, 0, 48, 32, 56}, + {12, 14, 2, 10, 14, 6, 7, 1, 5, 7, 48, 56, 8, 40, 56}, + {9, 0, 18, 12, 21, 15, 0, 30, 20, 35, 0, 0, 0, 0, 0}, + {18, 21, 3, 15, 21, 30, 35, 5, 25, 35, 0, 0, 0, 0, 0}, + {6, 0, 12, 8, 14, 18, 0, 36, 24, 42, 6, 0, 12, 8, 14}, + {12, 14, 2, 10, 14, 36, 42, 6, 30, 42, 12, 14, 2, 10, 14}, + {12, 0, 24, 16, 28, 12, 0, 24, 16, 28, 18, 0, 36, 24, 42}, + {24, 28, 4, 20, 28, 24, 28, 4, 20, 28, 36, 42, 6, 30, 42} + }; + + CHECK_EQ(expected, res); + } + + TEST_CASE("cholesky") + { + xarray arg_0 = {{4, 12, -16}, {12, 37, -43}, {-16, -43, 98}}; + + auto res = xt::linalg::cholesky(arg_0); + xarray expected = {{2., 0., 0.}, {6., 1., 0.}, {-8., 5., 3.}}; + CHECK_EQ(expected, res); + + xarray> cmplarg_0 = {{1. + 0.i, -0. - 2.i}, {0. + 2.i, 5. + 0.i}}; + auto cmplres = xt::linalg::cholesky(cmplarg_0); + xarray> cmplexpected = {{1. + 0.i, 0. + 0.i}, {0. + 2.i, 1. + 0.i}}; + CHECK_EQ(cmplexpected, cmplres); + } + + TEST_CASE("qr") + { + xarray a = xt::random::rand({9, 6}); + auto res = xt::linalg::qr(a); + xarray q = std::get<0>(res); + xarray r = std::get<1>(res); + auto resf = xt::linalg::qr(a, linalg::qrmode::complete); + auto resr = xt::linalg::qr(a, linalg::qrmode::r); + xarray qf = std::get<0>(resf); + xarray rf = std::get<1>(resf); + + auto neara = xt::linalg::dot(q, r); + CHECK(allclose(neara, a)); + auto nearaf = xt::linalg::dot(qf, rf); + CHECK(allclose(nearaf, a)); + + CHECK_EQ(std::get<1>(resr), xt::view(rf, xt::range(0, 6), xt::all())); + CHECK_EQ(std::get<0>(resr).size(), 0u); + CHECK_EQ(std::get<0>(resr).dimension(), 1u); + + xarray erawR = { + {-1.00444014e+01, 0.00000000e+00, 6.74440143e-01, 2.24813381e-01}, + {-9.58743044e+00, -1.25730337e+01, -6.22814365e-03, 3.37562246e-01}, + {-1.29027101e+01, -7.34080303e+00, -4.07831856e+00, -5.76331089e-01} + }; + + xarray eTau = {1.32854123, 1.79535299, 1.50132395}; + + xarray AA = + {{3.3, 1., 2.}, {0., 10., 8.}, {9., 7., 12.}, {3., 10., 5.}}; + + auto resraw = xt::linalg::qr(AA, linalg::qrmode::raw); + auto tau = std::get<1>(resraw); + auto rawR = std::get<0>(resraw); + + CHECK(allclose(tau, eTau)); + CHECK(allclose(erawR, rawR)); + } + + TEST_CASE("lstsq") + { + xarray arg_0 = {{0., 1.}, {1., 1.}, {2., 1.}, {3., 1.}}; + + xarray arg_1 = {{-1., 0.2, 0.9, 2.1}, {2., 3., 2., 1.}}; + arg_1 = transpose(arg_1); + auto res = xt::linalg::lstsq(arg_0, arg_1); + + xarray el_0 = {{1., -0.4}, {-0.95, 2.6}}; + xarray el_1 = {0.05, 1.2}; + int el_2 = 2; + xarray el_3 = {4.10003045, 1.09075677}; + + CHECK(allclose(el_0, std::get<0>(res))); + CHECK(allclose(el_1, std::get<1>(res))); + CHECK_EQ(el_2, std::get<2>(res)); + CHECK(allclose(el_3, std::get<3>(res))); - TEST(xlinalg, slogdet) - { - xarray> arg_0 = { - {0.13373658 + 0.43025551i, 0.42593478 + 0.17539337i, 0.18840853 + 0.24669458i}, - {0.82800224 + 0.11797823i, 0.40310379 + 0.14037109i, 0.88204561 + 0.96870283i}, - {0.35427657 + 0.1233739i, 0.22740960 + 0.94019582i, 0.05410180 + 0.86462543i} - }; - auto resc = linalg::slogdet(arg_0); - auto sc = std::get<0>(resc); - auto sl = std::real(std::get<1>(resc)); - auto scr = std::real(sc); - auto sci = std::imag(sc); - - EXPECT_NEAR(-0.8818794751329891, sl, 1e-06); - EXPECT_NEAR(0.8473375077176295, scr, 1e-06); - EXPECT_NEAR(0.5310547504870624, sci, 1e-06); - - xarray arg_b = { - {0.20009016, 0.33997118, 0.74433611}, - {0.52721448, 0.2449798, 0.49085606}, - {0.49757477, 0.97304175, 0.05011255} - }; - auto res = linalg::slogdet(arg_b); - double expected_0 = 1.0; - double expected_1 = -1.3017524147193602; - auto sres = std::get<0>(res); - auto lres = std::get<1>(res); - EXPECT_EQ(expected_0, sres); - EXPECT_NEAR(expected_1, lres, 1e-06); - } + xarray> carg_0 = {{0., 1.}, {1. - 3i, 1.}, {2., 1.}, {3., 1.}}; + xarray> carg_1 = {{-1., 0.2 + 4i, 0.9, 2.1 - 1i}, {2, 3i, 2, 1}}; + carg_1 = transpose(carg_1); + auto cres = xt::linalg::lstsq(carg_0, carg_1); - TEST(xlinalg, svd) - { - xarray arg_0 = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}; - - auto res = linalg::svd(arg_0); - - xarray expected_0 = { - {-0.13511895, 0.90281571, 0.40824829}, - {-0.49633514, 0.29493179, -0.81649658}, - {-0.85755134, -0.31295213, 0.40824829} - }; - xarray expected_1 = {1.42267074e+01, 1.26522599e+00, 5.89938022e-16}; - xarray expected_2 = { - {-0.4663281, -0.57099079, -0.67565348}, - {-0.78477477, -0.08545673, 0.61386131}, - {-0.40824829, 0.81649658, -0.40824829} - }; - - EXPECT_TRUE(allclose(std::get<0>(res), expected_0)); - EXPECT_TRUE(allclose(std::get<1>(res), expected_1)); - EXPECT_TRUE(allclose(std::get<2>(res), expected_2)); - } + xarray, layout_type::column_major> cel_0 = { + {-0.40425532 - 0.38723404i, -0.61702128 - 0.44680851i}, + {1.44680851 + 1.02765957i, 2.51063830 + 0.95744681i} + }; + xarray cel_1 = {16.11787234, 2.68085106}; + int cel_2 = 2; + xarray cel_3 = {5.01295356, 1.36758789}; - TEST(xlinalg, svd_horizontal_vertical) - { - xarray a = xt::ones({3, 1}); - xarray b = xt::ones({1, 3}); - xarray u, s, vt; + CHECK(allclose(imag(cel_0), imag(std::get<0>(cres)))); + CHECK(allclose(real(cel_0), real(std::get<0>(cres)))); + CHECK(allclose(cel_1, std::get<1>(cres))); + CHECK_EQ(cel_2, std::get<2>(cres)); + CHECK(allclose(cel_3, std::get<3>(cres))); + } - std::tie(u, s, vt) = linalg::svd(a, false); - EXPECT_TRUE(allclose(a, xt::linalg::dot(u * s, vt))); + TEST_CASE("trace") + { + auto e1 = eye(10); + xarray e2 = eye(5); - std::tie(u, s, vt) = linalg::svd(b, false); - EXPECT_TRUE(allclose(b, xt::linalg::dot(u * s, vt))); - } + auto t1 = linalg::trace(e1); + auto t11 = linalg::trace(e1, 1); + auto t1n1 = linalg::trace(e1, -1); + CHECK_EQ(10, t1()); + CHECK_EQ(0, t11()); + CHECK_EQ(0, t1n1()); - TEST(xlinalg, matrix_rank) - { - xarray eall = eye(4); - int a = linalg::matrix_rank(eall); - EXPECT_EQ(4, a); - - xarray b = eye(4); - b(1, 1) = 0; - int rb = linalg::matrix_rank(b); - EXPECT_EQ(3, rb); - xarray ones_arr = ones({4, 4}); - int ro = linalg::matrix_rank(ones_arr); - EXPECT_EQ(1, ro); - xarray zarr = zeros({4, 4}); - int rz = linalg::matrix_rank(zarr); - EXPECT_EQ(0, rz); - } + auto t2 = linalg::trace(e2); + auto t22 = linalg::trace(e2, 1); + CHECK_EQ(5, t2()); + CHECK_EQ(0, t22()); - TEST(xlinalg, eigh) - { - xarray arg_0 = {{-761., -208., -582.}, {-208., -623., -1605.5}, {-582., -1605.5, -476.}}; - auto res = xt::linalg::eigh(arg_0); - xarray expected_0 = {-2351.3290686, -609.79206435, 1101.12113295}; - xarray expected_1 = { - {-0.33220683, -0.93041946, -0.15478453}, - {-0.66309119, 0.34708777, -0.66320446}, - {-0.67078216, 0.11768479, 0.73225787} - }; - auto vals = std::get<0>(res); - auto vecs = std::get<1>(res); - EXPECT_TRUE(allclose(expected_0, vals)); - EXPECT_TRUE(allclose(expected_1, vecs)); - - auto vals_2 = xt::linalg::eigvalsh(arg_0); - EXPECT_TRUE(allclose(expected_0, vals_2)); - - xarray> complarg_0 = {{1. + 0.i, -0. - 2.i}, {0. + 2.i, 5. + 0.i}}; - auto complres = xt::linalg::eigh(complarg_0); - - xarray complexpected_0 = {0.17157288, 5.82842712}; - auto cmvals = std::get<0>(complres); - auto cmvecs = std::get<1>(complres); - EXPECT_TRUE(allclose(complexpected_0, cmvals)); - xarray, layout_type::column_major> complexpected_1 = { - {-0.92387953 + 0.i, -0.38268343 + 0.i}, - {0.00000000 + 0.38268343i, 0.00000000 - 0.92387953i} - }; - EXPECT_TRUE(allclose(imag(complexpected_1), imag(cmvecs))); - EXPECT_TRUE(allclose(real(complexpected_1), real(cmvecs))); - - auto cmvals2 = xt::linalg::eigvalsh(complarg_0); - EXPECT_TRUE(allclose(complexpected_0, cmvals2)); - } + xarray ar = xt::arange(9); + ar.reshape({3, 3}); - TEST(xlinalg, pinv) - { - xarray arg_0 = { - {1.47351391, 0.94686323, 0.92236842, -1.44141916, -1.53123963, -0.36949144}, - {-0.76686921, -0.01087083, -1.11100036, -0.59745592, -0.99849726, 0.45296729}, - {-0.35274989, -1.27760231, 1.50092545, -2.7243503, -0.79326768, -1.00826405}, - {0.05763039, 1.04069983, -0.502178, -1.01776144, 0.6496664, -0.2374513}, - {-1.45517735, 0.42523508, 0.41400096, 0.87164292, 1.87754145, 0.16358461}, - {1.07487297, -0.26417364, 1.82998799, 0.97985789, -0.74820612, -0.75097366}, - {0.91375249, 1.14211989, -0.23055478, -0.48264987, -0.4591723, 0.83185472}, - {0.05318152, -0.30014836, 1.68456715, 0.07388112, 0.0607432, -0.51529535}, - {-1.36227295, -0.12015569, -0.45599178, -1.07135129, -0.27405687, 0.50177945} - }; - auto res = xt::linalg::pinv(arg_0); - xarray expected = { - {-0.12524671, -0.41299325, 0.09108576, -0.07346514, -0.29603324, -0.1702256, 0.28503799, -0.08979346, -0.29415286 - }, - {0.38896886, 0.28355746, -0.26406943, 0.36828839, 0.29271933, 0.19368219, -0.11433648, 0.05102866, 0.11050527 - }, - {0.05803293, -0.11353613, 0.08736754, -0.20744326, 0.21647828, 0.11201996, 0.26187311, 0.24550066, 0.13766844 - }, - {0.01037663, 0.15710257, -0.24647364, -0.06611398, 0.0482986, 0.21835662, -0.18763349, 0.01747553, -0.02066564 - }, - {-0.23681835, -0.44982904, 0.14998634, 0.07608489, 0.04751429, -0.27436861, 0.18626331, -0.01520331, -0.18144752 - }, - {-0.25580995, -0.33606213, 0.10424938, -0.64100544, 0.01084618, -0.28963573, 0.89629794, 0.12987378, 0.22451995 - } - }; - EXPECT_TRUE(allclose(expected, res)); - - xarray> cmpl_arg_0 = { - {-0.32865615 + 1.56868725i, 0.28804396 + 0.52266479i}, - {-1.29703842 + 0.34647524i, -2.14982936 + 0.31425111i}, - {-0.69224750 - 1.36725801i, 2.22948403 + 1.4612309i} - }; - auto cmpl_res = xt::linalg::pinv(cmpl_arg_0); - xarray> cmpl_expected = { - {-0.06272312 - 0.24840107i, -0.20530381 - 0.00548715i, -0.14179276 + 0.16337684i}, - {0.05975312 - 0.0502577i, -0.17431091 - 0.05525696i, 0.16047967 - 0.14140846i} - }; - EXPECT_TRUE(allclose(real(cmpl_expected), real(cmpl_res))); - EXPECT_TRUE(allclose(imag(cmpl_expected), imag(cmpl_res))); - } + auto ar1 = linalg::trace(ar); + auto ar2 = linalg::trace(ar, 1); + auto ar3 = linalg::trace(ar, -1); - TEST(xlinalg, pinv_small) - { - xt::xtensor d1{{1.f, 2.f}}; - auto r1 = xt::linalg::pinv(d1); - xt::xtensor e1 = {{0.2f}, {0.4f}}; - EXPECT_TRUE(allclose(r1, e1)); - - xt::xtensor d2{{1.f}}; - auto r2 = xt::linalg::pinv(d2); - EXPECT_EQ(r2(0), 1.f); - } + CHECK_EQ(12, ar1()); + CHECK_EQ(6, ar2()); + CHECK_EQ(10, ar3()); + } - TEST(xlinalg, mat_norm) - { - xarray arg_0 = { - {0.06817001, 0.50274712, -0.36802027, -0.93123204}, - {-0.5990272, -0.67439921, -0.09397038, -1.55915724}, - {2.22694395, 0.59099048, -0.43162172, 0.19410077}, - {0.41859591, 1.68555153, 1.82660739, 1.24427635} - }; - auto res1 = xt::linalg::norm(arg_0, 1); - auto res2 = xt::linalg::norm(arg_0, linalg::normorder::frob); - auto res3 = xt::linalg::norm(arg_0, linalg::normorder::inf); - auto res4 = xt::linalg::norm(arg_0, linalg::normorder::neg_inf); - auto res5 = xt::linalg::norm(arg_0, linalg::normorder::nuc); - auto res6 = xt::linalg::norm(arg_0, 2); - double exp1 = 3.92876639061; - double exp2 = 4.23639347394; - double exp3 = 5.17503118283; - double exp4 = 1.87016943835; - double exp5 = 7.42677006218; - double exp6 = 3.29152325862; - - EXPECT_NEAR(exp1, res1, 1e-06); - EXPECT_NEAR(exp2, res2, 1e-06); - EXPECT_NEAR(exp3, res3, 1e-06); - EXPECT_NEAR(exp4, res4, 1e-06); - EXPECT_NEAR(exp5, res5, 1e-06); - EXPECT_NEAR(exp6, res6, 1e-06); - - xarray> cmplarg_0 = { - {0.40101756 + 0.71233018i, 0.62731701 + 0.42786349i, 0.32415089 + 0.2977805i}, - {0.24475928 + 0.49208478i, 0.69475518 + 0.74029639i, 0.59390240 + 0.35772892i}, - {0.63179202 + 0.41720995i, 0.44025718 + 0.65472131i, 0.08372648 + 0.37380143i} - }; - auto cmplres1 = xt::linalg::norm(cmplarg_0, 1); - auto cmplres2 = xt::linalg::norm(cmplarg_0, linalg::normorder::frob); - auto cmplres3 = xt::linalg::norm(cmplarg_0, linalg::normorder::inf); - auto cmplres4 = xt::linalg::norm(cmplarg_0, linalg::normorder::neg_inf); - auto cmplres5 = xt::linalg::norm(cmplarg_0, linalg::normorder::nuc); - auto cmplres6 = xt::linalg::norm(cmplarg_0, 2); - - double cmplexp1 = 2.56356133004; - double cmplexp2 = 2.14347558031; - double cmplexp3 = 2.25815855456; - double cmplexp4 = 1.92915797164; - double cmplexp5 = 2.77947580342; - double cmplexp6 = 2.0683368289; - - EXPECT_NEAR(cmplexp1, cmplres1, 1e-06); - EXPECT_NEAR(cmplexp2, cmplres2, 1e-06); - EXPECT_NEAR(cmplexp3, cmplres3, 1e-06); - EXPECT_NEAR(cmplexp4, cmplres4, 1e-06); - EXPECT_NEAR(cmplexp5, cmplres5, 1e-06); - EXPECT_NEAR(cmplexp6, cmplres6, 1e-06); - } + TEST_CASE("dots") + { + xarray arg_0 = { + {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, - TEST(xlinalg, vec_norm) - { - xarray arg_0 = {0.23451288, 0.98799529, 0.76599595, 0.77700444, 0.02798196}; - - EXPECT_NEAR(2.79349050582, xt::linalg::norm(arg_0, 1), 1e-6); - EXPECT_NEAR(1.49077149771, xt::linalg::norm(arg_0, 2), 1e-6); - EXPECT_NEAR(1.23766843269, xt::linalg::norm(arg_0, 3), 1e-6); - EXPECT_NEAR(1.13587319901, xt::linalg::norm(arg_0, 4), 1e-6); - EXPECT_NEAR(5.0, xt::linalg::norm(arg_0, 0), 1e-6); - EXPECT_NEAR(0.0229325662443, xt::linalg::norm(arg_0, -1), 1e-6); - EXPECT_NEAR(0.0277379546324, xt::linalg::norm(arg_0, -2), 1e-6); - EXPECT_NEAR(0.987995286517, xt::linalg::norm(arg_0, linalg::normorder::inf), 1e-6); - EXPECT_NEAR(0.0279819550429, xt::linalg::norm(arg_0, linalg::normorder::neg_inf), 1e-6); - - xarray> arg_1 = { - 0.23451288 + 0.77700444i, - 0.98799529 + 0.02798196i, - 0.76599595 + 0.17390652i - }; - EXPECT_NEAR(2.58550383197, xt::linalg::norm(arg_1, 1), 1e-06); - EXPECT_NEAR(1.50088078633, xt::linalg::norm(arg_1, 2), 1e-06); - EXPECT_NEAR(1.25673399279, xt::linalg::norm(arg_1, 3), 1e-06); - EXPECT_NEAR(1.15326879931, xt::linalg::norm(arg_1, 4), 1e-06); - EXPECT_NEAR(3.0, xt::linalg::norm(arg_1, 0), 1e-06); - EXPECT_NEAR(0.284338433895, xt::linalg::norm(arg_1, -1), 1e-06); - EXPECT_NEAR(0.490145522524, xt::linalg::norm(arg_1, -2), 1e-06); - EXPECT_NEAR(0.98839145888, xt::linalg::norm(arg_1, linalg::normorder::inf), 1e-06); - EXPECT_NEAR(0.785489192861, xt::linalg::norm(arg_1, linalg::normorder::neg_inf), 1e-06); - } + {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}} + }; - TEST(xlinalg, vdot) - { - xarray arg_0 = {0.23451288, 0.98799529, 0.76599595, 0.77700444, 0.02798196}; - xarray arg_1 = {0.17390652, 0.15408224, 0.07708648, 0.8898657, 0.7503787}; - auto res = xt::linalg::vdot(arg_0, arg_1); - EXPECT_NEAR(0.964490439715, res, 1e-06); - - xarray> carg_0 = { - 0.23451288 + 0.17390652i, - 0.98799529 + 0.15408224i, - 0.76599595 + 0.07708648i, - 0.77700444 + 0.8898657i, - 0.02798196 + 0.7503787i - }; - xarray> carg_1 = { - 0.17390652 + 0.23451288i, - 0.15408224 + 0.98799529i, - 0.07708648 + 0.76599595i, - 0.88986570 + 0.77700444i, - 0.75037870 + 0.02798196i - }; - auto res_c = xt::linalg::vdot(carg_0, carg_1); - - EXPECT_NEAR(1.9289808794290355, std::real(res_c), 1e-06); - EXPECT_NEAR(0.8075433553117102, std::imag(res_c), 1e-06); - } + xarray arg_1 = { + {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, - TEST(xlinalg, kron) - { - xarray arg_0 = {{2, 1, 8}, {3, 5, 0}, {2, 6, 2}, {4, 4, 6}}; + {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}, - xarray arg_1 = {{3, 0, 6, 4, 7}, {6, 7, 1, 5, 7}}; + {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}} + }; - auto res = xt::linalg::kron(arg_0, arg_1); + auto res1 = xt::linalg::dot(arg_0, arg_1); + xarray expected1 = { + {{{15, 18, 21}, {42, 45, 48}, {69, 72, 75}}, - xarray expected = { - {6, 0, 12, 8, 14, 3, 0, 6, 4, 7, 24, 0, 48, 32, 56}, - {12, 14, 2, 10, 14, 6, 7, 1, 5, 7, 48, 56, 8, 40, 56}, - {9, 0, 18, 12, 21, 15, 0, 30, 20, 35, 0, 0, 0, 0, 0}, - {18, 21, 3, 15, 21, 30, 35, 5, 25, 35, 0, 0, 0, 0, 0}, - {6, 0, 12, 8, 14, 18, 0, 36, 24, 42, 6, 0, 12, 8, 14}, - {12, 14, 2, 10, 14, 36, 42, 6, 30, 42, 12, 14, 2, 10, 14}, - {12, 0, 24, 16, 28, 12, 0, 24, 16, 28, 18, 0, 36, 24, 42}, - {24, 28, 4, 20, 28, 24, 28, 4, 20, 28, 36, 42, 6, 30, 42} - }; + {{42, 54, 66}, {150, 162, 174}, {258, 270, 282}}, - EXPECT_EQ(expected, res); - } + {{69, 90, 111}, {258, 279, 300}, {447, 468, 489}}}, - TEST(xlinalg, cholesky) - { - xarray arg_0 = {{4, 12, -16}, {12, 37, -43}, {-16, -43, 98}}; + {{{96, 126, 156}, {366, 396, 426}, {636, 666, 696}}, - auto res = xt::linalg::cholesky(arg_0); - xarray expected = {{2., 0., 0.}, {6., 1., 0.}, {-8., 5., 3.}}; - EXPECT_EQ(expected, res); + {{123, 162, 201}, {474, 513, 552}, {825, 864, 903}}, - xarray> cmplarg_0 = {{1. + 0.i, -0. - 2.i}, {0. + 2.i, 5. + 0.i}}; - auto cmplres = xt::linalg::cholesky(cmplarg_0); - xarray> cmplexpected = {{1. + 0.i, 0. + 0.i}, {0. + 2.i, 1. + 0.i}}; - EXPECT_EQ(cmplexpected, cmplres); - } - - TEST(xlinalg, qr) - { - xarray a = xt::random::rand({9, 6}); - auto res = xt::linalg::qr(a); - xarray q = std::get<0>(res); - xarray r = std::get<1>(res); - auto resf = xt::linalg::qr(a, linalg::qrmode::complete); - auto resr = xt::linalg::qr(a, linalg::qrmode::r); - xarray qf = std::get<0>(resf); - xarray rf = std::get<1>(resf); - - auto neara = xt::linalg::dot(q, r); - EXPECT_TRUE(allclose(neara, a)); - auto nearaf = xt::linalg::dot(qf, rf); - EXPECT_TRUE(allclose(nearaf, a)); - - EXPECT_EQ(std::get<1>(resr), xt::view(rf, xt::range(0, 6), xt::all())); - EXPECT_EQ(std::get<0>(resr).size(), 0u); - EXPECT_EQ(std::get<0>(resr).dimension(), 1u); - - xarray erawR = { - {-1.00444014e+01, 0.00000000e+00, 6.74440143e-01, 2.24813381e-01}, - {-9.58743044e+00, -1.25730337e+01, -6.22814365e-03, 3.37562246e-01}, - {-1.29027101e+01, -7.34080303e+00, -4.07831856e+00, -5.76331089e-01} - }; - - xarray eTau = {1.32854123, 1.79535299, 1.50132395}; - - xarray AA = {{3.3, 1., 2.}, {0., 10., 8.}, {9., 7., 12.}, {3., 10., 5.}}; - - auto resraw = xt::linalg::qr(AA, linalg::qrmode::raw); - auto tau = std::get<1>(resraw); - auto rawR = std::get<0>(resraw); - - EXPECT_TRUE(allclose(tau, eTau)); - EXPECT_TRUE(allclose(erawR, rawR)); - } - - TEST(xlinalg, lstsq) - { - xarray arg_0 = {{0., 1.}, {1., 1.}, {2., 1.}, {3., 1.}}; - - xarray arg_1 = {{-1., 0.2, 0.9, 2.1}, {2., 3., 2., 1.}}; - arg_1 = transpose(arg_1); - auto res = xt::linalg::lstsq(arg_0, arg_1); - - xarray el_0 = {{1., -0.4}, {-0.95, 2.6}}; - xarray el_1 = {0.05, 1.2}; - int el_2 = 2; - xarray el_3 = {4.10003045, 1.09075677}; - - EXPECT_TRUE(allclose(el_0, std::get<0>(res))); - EXPECT_TRUE(allclose(el_1, std::get<1>(res))); - EXPECT_EQ(el_2, std::get<2>(res)); - EXPECT_TRUE(allclose(el_3, std::get<3>(res))); - - xarray> carg_0 = {{0., 1.}, {1. - 3i, 1.}, {2., 1.}, {3., 1.}}; - xarray> carg_1 = {{-1., 0.2 + 4i, 0.9, 2.1 - 1i}, {2, 3i, 2, 1}}; - carg_1 = transpose(carg_1); - auto cres = xt::linalg::lstsq(carg_0, carg_1); - - xarray, layout_type::column_major> cel_0 = { - {-0.40425532 - 0.38723404i, -0.61702128 - 0.44680851i}, - {1.44680851 + 1.02765957i, 2.51063830 + 0.95744681i} - }; - xarray cel_1 = {16.11787234, 2.68085106}; - int cel_2 = 2; - xarray cel_3 = {5.01295356, 1.36758789}; - - EXPECT_TRUE(allclose(imag(cel_0), imag(std::get<0>(cres)))); - EXPECT_TRUE(allclose(real(cel_0), real(std::get<0>(cres)))); - EXPECT_TRUE(allclose(cel_1, std::get<1>(cres))); - EXPECT_EQ(cel_2, std::get<2>(cres)); - EXPECT_TRUE(allclose(cel_3, std::get<3>(cres))); - } - - TEST(xlinalg, trace) - { - auto e1 = eye(10); - xarray e2 = eye(5); - - auto t1 = linalg::trace(e1); - auto t11 = linalg::trace(e1, 1); - auto t1n1 = linalg::trace(e1, -1); - EXPECT_EQ(10, t1()); - EXPECT_EQ(0, t11()); - EXPECT_EQ(0, t1n1()); - - auto t2 = linalg::trace(e2); - auto t22 = linalg::trace(e2, 1); - EXPECT_EQ(5, t2()); - EXPECT_EQ(0, t22()); - - xarray ar = xt::arange(9); - ar.reshape({3, 3}); - - auto ar1 = linalg::trace(ar); - auto ar2 = linalg::trace(ar, 1); - auto ar3 = linalg::trace(ar, -1); - - EXPECT_EQ(12, ar1()); - EXPECT_EQ(6, ar2()); - EXPECT_EQ(10, ar3()); - } - - TEST(xlinalg, dots) - { - xarray arg_0 = { - {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, - - {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}} - }; + {{150, 198, 246}, {582, 630, 678}, {1014, 1062, 1110}}} + }; - xarray arg_1 = { - {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}}, + CHECK(allclose(expected1, res1)); - {{9, 10, 11}, {12, 13, 14}, {15, 16, 17}}, + auto res2 = xt::linalg::dot(arg_1, arg_0); + xarray expected2 = { + {{{15, 18, 21}, {42, 45, 48}}, - {{18, 19, 20}, {21, 22, 23}, {24, 25, 26}} - }; + {{42, 54, 66}, {150, 162, 174}}, - auto res1 = xt::linalg::dot(arg_0, arg_1); - xarray expected1 = { - {{{15, 18, 21}, {42, 45, 48}, {69, 72, 75}}, + {{69, 90, 111}, {258, 279, 300}}}, - {{42, 54, 66}, {150, 162, 174}, {258, 270, 282}}, + {{{96, 126, 156}, {366, 396, 426}}, - {{69, 90, 111}, {258, 279, 300}, {447, 468, 489}}}, + {{123, 162, 201}, {474, 513, 552}}, - {{{96, 126, 156}, {366, 396, 426}, {636, 666, 696}}, + {{150, 198, 246}, {582, 630, 678}}}, - {{123, 162, 201}, {474, 513, 552}, {825, 864, 903}}, + {{{177, 234, 291}, {690, 747, 804}}, - {{150, 198, 246}, {582, 630, 678}, {1014, 1062, 1110}}} - }; + {{204, 270, 336}, {798, 864, 930}}, - EXPECT_TRUE(allclose(expected1, res1)); + {{231, 306, 381}, {906, 981, 1056}}} + }; - auto res2 = xt::linalg::dot(arg_1, arg_0); - xarray expected2 = { - {{{15, 18, 21}, {42, 45, 48}}, + CHECK(allclose(expected2, res2)); - {{42, 54, 66}, {150, 162, 174}}, + xarray arg_2 = {0, 1, 2}; + auto res3 = xt::linalg::dot(arg_0, arg_2); - {{69, 90, 111}, {258, 279, 300}}}, + xarray expected3 = {{5, 14, 23}, {32, 41, 50}}; - {{{96, 126, 156}, {366, 396, 426}}, + CHECK(allclose(expected3, res3)); - {{123, 162, 201}, {474, 513, 552}}, + auto res4 = xt::linalg::dot(arg_2, arg_0); - {{150, 198, 246}, {582, 630, 678}}}, + xarray expected4 = {{15, 18, 21}, {42, 45, 48}}; - {{{177, 234, 291}, {690, 747, 804}}, + CHECK(allclose(expected4, res4)); + } - {{204, 270, 336}, {798, 864, 930}}, + TEST_CASE("negative_strides") + { + xt::xarray A = {{2, 3}, {5, 7}, {11, 13}}; - {{231, 306, 381}, {906, 981, 1056}}} - }; + auto A1 = xt::view(A, xt::range(0, 3), 0); + auto A2 = xt::view(A, xt::range(-1, -4, -1), 1); - EXPECT_TRUE(allclose(expected2, res2)); + auto res = xt::linalg::dot(A1, A2); + CHECK_EQ(res(), 94); + } - xarray arg_2 = {0, 1, 2}; - auto res3 = xt::linalg::dot(arg_0, arg_2); + TEST_CASE("asserts") + { + CHECK_THROWS_AS(xt::linalg::eigh(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eig(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS( + xt::linalg::solve(xt::ones({3, 1}), xt::ones({3, 1})), + std::runtime_error + ); + CHECK_THROWS_AS(xt::linalg::inv(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eigvals(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eigvalsh(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::det(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::slogdet(xt::ones({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::cholesky(xt::ones({3, 1})), std::runtime_error); - xarray expected3 = {{5, 14, 23}, {32, 41, 50}}; - - EXPECT_TRUE(allclose(expected3, res3)); - - auto res4 = xt::linalg::dot(arg_2, arg_0); - - xarray expected4 = {{15, 18, 21}, {42, 45, 48}}; - - EXPECT_TRUE(allclose(expected4, res4)); - } - - TEST(xlinalg, negative_strides) - { - xt::xarray A = {{2, 3}, {5, 7}, {11, 13}}; - - auto A1 = xt::view(A, xt::range(0, 3), 0); - auto A2 = xt::view(A, xt::range(-1, -4, -1), 1); - - auto res = xt::linalg::dot(A1, A2); - EXPECT_EQ(res(), 94); - } - - TEST(xlinalg, asserts) - { - EXPECT_THROW(xt::linalg::eigh(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::eig(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::solve(xt::ones({3, 1}), xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::inv(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::eigvals(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::eigvalsh(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::det(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::slogdet(xt::ones({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::cholesky(xt::ones({3, 1})), std::runtime_error); - - EXPECT_THROW(xt::linalg::eigh(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::eig(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW( - xt::linalg::solve(xt::ones>({3, 1}), xt::ones>({3, 1})), - std::runtime_error - ); - EXPECT_THROW(xt::linalg::inv(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::eigvals(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::eigvalsh(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::det(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::slogdet(xt::ones>({3, 1})), std::runtime_error); - EXPECT_THROW(xt::linalg::cholesky(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eigh(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eig(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS( + xt::linalg::solve(xt::ones>({3, 1}), xt::ones>({3, 1})), + std::runtime_error + ); + CHECK_THROWS_AS(xt::linalg::inv(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eigvals(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::eigvalsh(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::det(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::slogdet(xt::ones>({3, 1})), std::runtime_error); + CHECK_THROWS_AS(xt::linalg::cholesky(xt::ones>({3, 1})), std::runtime_error); + } } +#undef EXPECT_NEAR } // namespace xt diff --git a/test/test_lstsq.cpp b/test/test_lstsq.cpp index 5ebe8a8..b218d1e 100644 --- a/test/test_lstsq.cpp +++ b/test/test_lstsq.cpp @@ -6,9 +6,8 @@ * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ +// This file is generated from test/files/cppy_source/test_lstsq.cppy by preprocess.py! -// This file is generated from test/files/cppy_source/test_lstsq.cppy by -// preprocess.py! #include @@ -19,355 +18,379 @@ #include "xtensor/views/xstrided_view.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xlinalg.hpp" namespace xt { using namespace xt::placeholders; - /*py - a = np.random.random((6, 3)) - b = np.ones((6)) - */ - TEST(xtest_extended, lstsq1) + TEST_SUITE("xlstsq_extended") { - // py_a - xarray py_a = { - {0.3745401188473625, 0.9507143064099162, 0.7319939418114051}, - {0.5986584841970366, 0.1560186404424365, 0.1559945203362026}, - {0.0580836121681995, 0.8661761457749352, 0.6011150117432088}, - {0.7080725777960455, 0.0205844942958024, 0.9699098521619943}, - {0.8324426408004217, 0.2123391106782762, 0.1818249672071006}, - {0.1834045098534338, 0.3042422429595377, 0.5247564316322378} - }; - // py_b - xarray py_b = {1., 1., 1., 1., 1., 1.}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = {0.99525656797683, 0.6379298291900684, 0.416589303565964}; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {0.3378625895661748}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 3; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {2.081504268698353, 1.012756249516551, 0.599044658280111}; + /*py + a = np.random.random((6, 3)) + b = np.ones((6)) + */ + TEST_CASE("lstsq1") + { + // py_a + xarray py_a = { + {0.3745401188473625, 0.9507143064099162, 0.7319939418114051}, + {0.5986584841970366, 0.1560186404424365, 0.1559945203362026}, + {0.0580836121681995, 0.8661761457749352, 0.6011150117432088}, + {0.7080725777960455, 0.0205844942958024, 0.9699098521619943}, + {0.8324426408004217, 0.2123391106782762, 0.1818249672071006}, + {0.1834045098534338, 0.3042422429595377, 0.5247564316322378} + }; + // py_b + xarray py_b = {1., 1., 1., 1., 1., 1.}; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = {0.9952565679768294, 0.6379298291900682, 0.4165893035659636}; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {0.3378625895661749}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 3; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {2.081504268698354, 1.0127562495165512, 0.5990446582801109}; - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - /*py - a = np.random.random((3, 3)) - b = np.ones((3)) - */ - TEST(xtest_extended, lstsq20) - { - // py_a - xarray py_a = { - {0.4319450186421158, 0.2912291401980419, 0.6118528947223795}, - {0.1394938606520418, 0.2921446485352182, 0.3663618432936917}, - {0.4560699842170359, 0.7851759613930136, 0.1996737821583597} - }; - // py_b - xarray py_b = {1., 1., 1.}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = {-1.655587220862159, 1.7320451450169407, 1.9787446378934206}; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 3; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {1.2339483753871052, 0.4580824861786693, 0.1291723342275802}; + /*py + a = np.random.random((3, 3)) + b = np.ones((3)) + */ + TEST_CASE("lstsq20") + { + // py_a + xarray py_a = { + {0.4319450186421158, 0.2912291401980419, 0.6118528947223795}, + {0.1394938606520418, 0.2921446485352182, 0.3663618432936917}, + {0.4560699842170359, 0.7851759613930136, 0.1996737821583597} + }; + // py_b + xarray py_b = {1., 1., 1.}; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = {-1.655587220862159, 1.7320451450169407, 1.9787446378934204}; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 3; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {1.2339483753871052, 0.4580824861786693, 0.1291723342275802}; - auto xres = xt::linalg::lstsq(py_a, py_b); + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } + /*py + a = np.random.random((3, 3)) + b = np.ones((3, 3)) + */ + TEST_CASE("lstsq21") + { + // py_a + xarray py_a = { + {0.5142344384136116, 0.5924145688620425, 0.0464504127199977}, + {0.6075448519014384, 0.1705241236872915, 0.0650515929852795}, + {0.9488855372533332, 0.9656320330745594, 0.8083973481164611} + }; + // py_b + xarray py_b = {{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = { + {1.6749237812267237, 1.6749237812267237, 1.6749237812267237}, + {0.3213797243357512, 0.3213797243357512, 0.3213797243357512}, + {-1.1128753832544371, -1.1128753832544371, -1.1128753832544371} + }; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 3; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {1.8090476189892228, 0.4005423925178662, 0.2705890168670333}; - /*py - a = np.random.random((3, 3)) - b = np.ones((3, 3)) - */ - TEST(xtest_extended, lstsq21) - { - // py_a - xarray py_a = { - {0.5142344384136116, 0.5924145688620425, 0.0464504127199977}, - {0.6075448519014384, 0.1705241236872915, 0.0650515929852795}, - {0.9488855372533332, 0.9656320330745594, 0.8083973481164611} - }; - // py_b - xarray py_b = {{1., 1., 1.}, {1., 1., 1.}, {1., 1., 1.}}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = { - {1.6749237812267237, 1.6749237812267237, 1.6749237812267237}, - {0.3213797243357512, 0.3213797243357512, 0.3213797243357512}, - {-1.1128753832544371, -1.1128753832544371, -1.1128753832544371} - }; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 3; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {1.8090476189892228, 0.4005423925178662, 0.2705890168670333}; + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - auto xres = xt::linalg::lstsq(py_a, py_b); + /*py + a = np.random.random((2, 5)) + b = np.ones((2)) + */ + TEST_CASE("lstsq3") + { + // py_a + xarray py_a = { + {0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013, 0.1220382348447788 + }, + {0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169, 0.662522284353982} + }; + // py_b + xarray py_b = {1., 1.}; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = { + 0.3137661125421979, + 0.183749537801855, + 0.8404557593671863, + 0.7586648365305538, + -0.1845363594995904 + }; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 2; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {1.4931292414997537, 0.3589512974668556}; - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - /*py - a = np.random.random((2, 5)) - b = np.ones((2)) - */ - TEST(xtest_extended, lstsq3) - { - // py_a - xarray py_a = { - {0.3046137691733707, 0.0976721140063839, 0.6842330265121569, 0.4401524937396013, 0.1220382348447788}, - {0.4951769101112702, 0.0343885211152184, 0.9093204020787821, 0.2587799816000169, 0.662522284353982} - }; - // py_b - xarray py_b = {1., 1.}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = - {0.3137661125421979, 0.183749537801855, 0.8404557593671863, 0.7586648365305537, -0.1845363594995904}; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 2; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {1.4931292414997537, 0.3589512974668556}; + /*py + a = np.random.random((2, 5)) + b = np.ones((2, 10)) + */ + TEST_CASE("lstsq4") + { + // py_a + xarray py_a = { + {0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527, 0.9695846277645586 + }, + {0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168} + }; + // py_b + xarray py_b = { + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.} + }; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = { + {-0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098, + -0.0723929848964098}, + {0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719, + 0.1423971374668719}, + {0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842, + 0.2187317829605842}, + {-0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432, + -0.1457627271119432}, + {0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988, + 0.8827197220374988} + }; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 2; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {2.23042850951828, 0.3968910268428817}; - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - /*py - a = np.random.random((2, 5)) - b = np.ones((2, 10)) - */ - TEST(xtest_extended, lstsq4) - { - // py_a - xarray py_a = { - {0.311711076089411, 0.5200680211778108, 0.5467102793432796, 0.184854455525527, 0.9695846277645586}, - {0.7751328233611146, 0.9394989415641891, 0.8948273504276488, 0.5978999788110851, 0.9218742350231168} - }; - // py_b - xarray py_b = {{1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, {1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = { - {-0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098, - -0.0723929848964098}, - {0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718, - 0.1423971374668718}, - {0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842, - 0.2187317829605842}, - {-0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433, - -0.1457627271119433}, - {0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499, - 0.882719722037499} - }; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 2; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {2.23042850951828, 0.3968910268428817}; + /*py + a = np.random.random((10, 5)) + b = np.ones((10, 20)) + */ + TEST_CASE("lstsq5") + { + // py_a + xarray py_a = { + {0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643, 0.388677289689482 + }, + {0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808, 0.5426960831582485 + }, + {0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173, 0.7722447692966574 + }, + {0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171, 0.7290071680409873 + }, + {0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297, 0.8631034258755935 + }, + {0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622, 0.325183322026747 + }, + {0.7296061783380641, 0.6375574713552131, 0.8872127425763265, 0.4722149251619493, 0.1195942459383017 + }, + {0.713244787222995, 0.7607850486168974, 0.5612771975694962, 0.770967179954561, 0.4937955963643907 + }, + {0.5227328293819941, 0.4275410183585496, 0.0254191267440952, 0.1078914269933045, 0.0314291856867343 + }, + {0.6364104112637804, 0.3143559810763267, 0.5085706911647028, 0.907566473926093, 0.2492922291488749} + }; + // py_b + xarray py_b = { + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.} + }; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = { + {0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, + 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, + 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, + 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, + 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483}, + {0.3603784058338764, 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, + 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, + 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, + 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, + 0.3603784058338764, 0.3603784058338764, 0.3603784058338764, 0.3603784058338764}, + {-0.0288908468951096, -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, + -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, + -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, + -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, + -0.0288908468951096, -0.0288908468951096, -0.0288908468951096, -0.0288908468951096}, + {0.2739420182164652, 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, + 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, + 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, + 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, 0.2739420182164652, + 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651}, + {0.6381721647626308, 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, + 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, + 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, + 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, + 0.6381721647626308, 0.6381721647626308, 0.6381721647626308, 0.6381721647626308} + }; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {0.668387503414133, 0.668387503414133, 0.668387503414133, + 0.668387503414133, 0.668387503414133, 0.668387503414133, + 0.668387503414133, 0.668387503414133, 0.668387503414133, + 0.668387503414133, 0.668387503414133, 0.668387503414133, + 0.668387503414133, 0.668387503414133, 0.668387503414133, + 0.668387503414133, 0.668387503414133, 0.668387503414133, + 0.668387503414133, 0.668387503414133}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 5; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = { + 3.317877520855451, + 1.0262463009257716, + 0.9696565206896538, + 0.8384020117545181, + 0.5915006407947914 + }; - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - /*py - a = np.random.random((10, 5)) - b = np.ones((10, 20)) - */ - TEST(xtest_extended, lstsq5) - { - // py_a - xarray py_a = { - {0.0884925020519195, 0.1959828624191452, 0.0452272889105381, 0.3253303307632643, 0.388677289689482}, - {0.2713490317738959, 0.8287375091519293, 0.3567533266935893, 0.2809345096873808, 0.5426960831582485}, - {0.1409242249747626, 0.8021969807540397, 0.0745506436797708, 0.9868869366005173, 0.7722447692966574}, - {0.1987156815341724, 0.0055221171236024, 0.8154614284548342, 0.7068573438476171, 0.7290071680409873}, - {0.7712703466859457, 0.0740446517340904, 0.3584657285442726, 0.1158690595251297, 0.8631034258755935}, - {0.6232981268275579, 0.3308980248526492, 0.0635583502860236, 0.3109823217156622, 0.325183322026747}, - {0.7296061783380641, 0.6375574713552131, 0.8872127425763265, 0.4722149251619493, 0.1195942459383017}, - {0.713244787222995, 0.7607850486168974, 0.5612771975694962, 0.770967179954561, 0.4937955963643907}, - {0.5227328293819941, 0.4275410183585496, 0.0254191267440952, 0.1078914269933045, 0.0314291856867343}, - {0.6364104112637804, 0.3143559810763267, 0.5085706911647028, 0.907566473926093, 0.2492922291488749} - }; - // py_b - xarray py_b = { - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, - {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.} - }; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = { - {0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, - 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, - 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, - 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, 0.7695214798127482, - 0.7695214798127483, 0.7695214798127483, 0.7695214798127483, 0.7695214798127483}, - {0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, - 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, - 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, - 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, 0.3603784058338763, - 0.3603784058338762, 0.3603784058338762, 0.3603784058338763, 0.3603784058338763}, - {-0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, - -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, - -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, - -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, -0.0288908468951092, - -0.0288908468951093, -0.0288908468951093, -0.0288908468951092, -0.0288908468951092}, - {0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, - 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, - 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, - 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, 0.2739420182164651, - 0.2739420182164651, 0.2739420182164651, 0.2739420182164652, 0.2739420182164652}, - {0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, - 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, - 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, - 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, - 0.6381721647626307, 0.6381721647626307, 0.6381721647626307, 0.6381721647626307} - }; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331, 0.6683875034141331, - 0.6683875034141331, 0.6683875034141331}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 5; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = - {3.317877520855451, 1.0262463009257718, 0.9696565206896536, 0.8384020117545181, 0.5915006407947916}; - - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } - - /*py - a = np.array([[0., 1.]]) - b = np.array([1.]) - */ - TEST(xtest_extended, lstsq6) - { - // py_a - xarray py_a = {{0., 1.}}; - // py_b - xarray py_b = {1.}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = {0., 1.}; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 1; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {1.}; + /*py + a = np.array([[0., 1.]]) + b = np.array([1.]) + */ + TEST_CASE("lstsq6") + { + // py_a + xarray py_a = {{0., 1.}}; + // py_b + xarray py_b = {1.}; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = {0., 1.}; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 1; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {1.}; - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); - } + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } - /*py - a = np.array([[1.], [1.]]) - b = np.array([1., 1.]) - */ - TEST(xtest_extended, lstsq7) - { - // cannot use "// py_a" due to ambiguous initializer list conversion below - // xarray py_a = {{1.}, - // {1.}}; - xarray py_a = xt::ones({2, 1}); - // py_b - xarray py_b = {1., 1.}; - // py_res0 = np.linalg.lstsq(a, b)[0] - xarray py_res0 = {0.9999999999999997}; - // py_res1 = np.linalg.lstsq(a, b)[1] - xarray py_res1 = {2.2508083912556065e-33}; - // py_res2 = np.linalg.lstsq(a, b)[2] - int py_res2 = 1; - // py_res3 = np.linalg.lstsq(a, b)[3] - xarray py_res3 = {1.4142135623730951}; + /*py + a = np.array([[1.], [1.]]) + b = np.array([1., 1.]) + */ + TEST_CASE("lstsq7") + { + // cannot use "// py_a" due to ambiguous initializer list conversion below + // xarray py_a = {{1.}, + // {1.}}; + xarray py_a = xt::ones({2, 1}); + // py_b + xarray py_b = {1., 1.}; + // py_res0 = np.linalg.lstsq(a, b)[0] + xarray py_res0 = {0.9999999999999997}; + // py_res1 = np.linalg.lstsq(a, b)[1] + xarray py_res1 = {2.2508083912556065e-33}; + // py_res2 = np.linalg.lstsq(a, b)[2] + int py_res2 = 1; + // py_res3 = np.linalg.lstsq(a, b)[3] + xarray py_res3 = {1.4142135623730951}; - auto xres = xt::linalg::lstsq(py_a, py_b); - EXPECT_TRUE(xt::allclose(std::get<0>(xres), py_res0)); - EXPECT_TRUE(xt::allclose(std::get<1>(xres), py_res1)); - EXPECT_EQ(std::get<2>(xres), py_res2); - EXPECT_TRUE(xt::allclose(std::get<3>(xres), py_res3)); + auto xres = xt::linalg::lstsq(py_a, py_b); + CHECK(xt::allclose(std::get<0>(xres), py_res0)); + CHECK(xt::allclose(std::get<1>(xres), py_res1)); + CHECK_EQ(std::get<2>(xres), py_res2); + CHECK(xt::allclose(std::get<3>(xres), py_res3)); + } } - -} // namespace xt +} diff --git a/test/test_qr.cpp b/test/test_qr.cpp index 183f676..541bb92 100644 --- a/test/test_qr.cpp +++ b/test/test_qr.cpp @@ -6,9 +6,8 @@ * * * The full license is in the file LICENSE, distributed with this software. * ****************************************************************************/ +// This file is generated from test/files/cppy_source/test_qr.cppy by preprocess.py! -// This file is generated from test/files/cppy_source/test_qr.cppy by -// preprocess.py! #include @@ -19,422 +18,432 @@ #include "xtensor/views/xstrided_view.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xlinalg.hpp" namespace xt { using namespace xt::placeholders; - /*py - a = np.random.random((6, 3)) - res_q1 = np.linalg.qr(a, 'raw') - res_q2 = np.linalg.qr(a, 'complete') - res_q3 = np.linalg.qr(a, 'reduced') - res_q4 = np.linalg.qr(a, 'r') - */ - TEST(xtest_extended, qr1) + TEST_SUITE("xqr_extended") { - // py_a - xarray py_a = { - {0.3745401188473625, 0.9507143064099162, 0.7319939418114051}, - {0.5986584841970366, 0.1560186404424365, 0.1559945203362026}, - {0.0580836121681995, 0.8661761457749352, 0.6011150117432088}, - {0.7080725777960455, 0.0205844942958024, 0.9699098521619943}, - {0.8324426408004217, 0.2123391106782762, 0.1818249672071006}, - {0.1834045098534338, 0.3042422429595377, 0.5247564316322378} - }; - // py_resq1_h = res_q1[0] - xarray py_resq1_h = { - {-1.3152987216651169, - 0.3542695728401418, - 0.0343722790456067, - 0.4190178144924799, - 0.4926165861757361, - 0.1085337284576868}, - {-0.567877094797874, - 1.2223138676385652, - -0.5073775633545011, - 0.3838046167052855, - 0.3339455785740943, - -0.0869071101793681}, - {-1.0163710885529547, - 0.7215655008695085, - 0.7854784971183756, - -0.8184018010449023, - 0.3355103841692941, - -0.2743559826773574} - }; - // py_resq1_tau = res_q1[1] - xarray py_resq1_tau = {1.2847566964660388, 1.3124991842889797, 1.0766465015522177}; + /*py + a = np.random.random((6, 3)) + res_q1 = np.linalg.qr(a, 'raw') + res_q2 = np.linalg.qr(a, 'complete') + res_q3 = np.linalg.qr(a, 'reduced') + res_q4 = np.linalg.qr(a, 'r') + */ + TEST_CASE("qr1") + { + // py_a + xarray py_a = { + {0.3745401188473625, 0.9507143064099162, 0.7319939418114051}, + {0.5986584841970366, 0.1560186404424365, 0.1559945203362026}, + {0.0580836121681995, 0.8661761457749352, 0.6011150117432088}, + {0.7080725777960455, 0.0205844942958024, 0.9699098521619943}, + {0.8324426408004217, 0.2123391106782762, 0.1818249672071006}, + {0.1834045098534338, 0.3042422429595377, 0.5247564316322378} + }; + // py_resq1_h = res_q1[0] + xarray py_resq1_h = { + {-1.3152987216651169, + 0.3542695728401418, + 0.0343722790456067, + 0.4190178144924799, + 0.4926165861757361, + 0.1085337284576868}, + {-0.567877094797874, + 1.2223138676385652, + -0.507377563354501, + 0.3838046167052855, + 0.3339455785740943, + -0.0869071101793681}, + {-1.0163710885529547, + 0.7215655008695085, + 0.7854784971183754, + -0.8184018010449026, + 0.3355103841692942, + -0.2743559826773575} + }; + // py_resq1_tau = res_q1[1] + xarray py_resq1_tau = {1.2847566964660388, 1.3124991842889797, 1.0766465015522177}; - auto res1 = linalg::qr(py_a, linalg::qrmode::raw); - EXPECT_TRUE(allclose(std::get<0>(res1), py_resq1_h)); - EXPECT_TRUE(allclose(std::get<1>(res1), py_resq1_tau)); + auto res1 = linalg::qr(py_a, linalg::qrmode::raw); + CHECK(allclose(std::get<0>(res1), py_resq1_h)); + CHECK(allclose(std::get<1>(res1), py_resq1_tau)); - // py_resq2_q_cmpl = res_q2[0] - xarray py_resq2_q_cmpl = { - {-0.2847566964660388, - 0.6455031901264903, - -0.0295327810119745, - -0.5849049416686276, - -0.0730618203174815, - -0.3923203408230155}, - {-0.4551502060605353, - -0.0838170448559192, - -0.3133472182914374, - 0.0819245453270295, - -0.7892351407115685, - 0.2408791714587237}, - {-0.0441600156766425, - 0.6881200538051699, - 0.0760152664601146, - 0.7143224973945711, - 0.0235700722943727, - 0.0891638112668338}, - {-0.538335943107778, - -0.2332659103773061, - 0.7525061466150679, - 0.1447692100263401, - -0.0279639819291248, - -0.2603378924852557}, - {-0.6328924578795164, - -0.1203177215897514, - -0.4769214096589269, - 0.1040507467269481, - 0.5878955555305321, - 0.0326957112268427}, - {-0.1394394344284399, 0.1841243791750922, 0.31850193596774, -0.3303532438685529, 0.1575155429538277, 0.8433664457979998 - } - }; - // py_resq2_r_cmpl = res_q2[1] - xarray py_resq2_r_cmpl = { - {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, - {0., 1.2223138676385652, 0.7215655008695085}, - {0., 0., 0.7854784971183756}, - {0., 0., 0.}, - {0., 0., 0.}, - {0., 0., 0.} - }; + // py_resq2_q_cmpl = res_q2[0] + xarray py_resq2_q_cmpl = { + {-0.2847566964660388, + 0.6455031901264903, + -0.0295327810119745, + -0.5849049416686276, + -0.0730618203174815, + -0.3923203408230155}, + {-0.4551502060605353, + -0.0838170448559192, + -0.3133472182914375, + 0.0819245453270296, + -0.7892351407115688, + 0.2408791714587238}, + {-0.0441600156766425, + 0.6881200538051697, + 0.0760152664601147, + 0.7143224973945713, + 0.0235700722943726, + 0.0891638112668339}, + {-0.538335943107778, + -0.2332659103773061, + 0.7525061466150681, + 0.1447692100263398, + -0.0279639819291247, + -0.2603378924852559}, + {-0.6328924578795164, + -0.1203177215897514, + -0.4769214096589271, + 0.1040507467269484, + 0.5878955555305321, + 0.0326957112268428}, + {-0.1394394344284399, + 0.1841243791750922, + 0.3185019359677401, + -0.330353243868553, + 0.1575155429538277, + 0.8433664457979998} + }; + // py_resq2_r_cmpl = res_q2[1] + xarray py_resq2_r_cmpl = { + {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, + {0., 1.2223138676385652, 0.7215655008695085}, + {0., 0., 0.7854784971183754}, + {0., 0., 0.}, + {0., 0., 0.}, + {0., 0., 0.} + }; - auto res2 = linalg::qr(py_a, linalg::qrmode::complete); - EXPECT_TRUE(allclose(std::get<0>(res2), py_resq2_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res2), py_resq2_r_cmpl)); + auto res2 = linalg::qr(py_a, linalg::qrmode::complete); + CHECK(allclose(std::get<0>(res2), py_resq2_q_cmpl)); + CHECK(allclose(std::get<1>(res2), py_resq2_r_cmpl)); - // py_resq3_q_cmpl = res_q3[0] - xarray py_resq3_q_cmpl = { - {-0.2847566964660388, 0.6455031901264903, -0.0295327810119745}, - {-0.4551502060605353, -0.0838170448559192, -0.3133472182914374}, - {-0.0441600156766425, 0.6881200538051699, 0.0760152664601146}, - {-0.538335943107778, -0.2332659103773061, 0.7525061466150679}, - {-0.6328924578795164, -0.1203177215897514, -0.4769214096589269}, - {-0.1394394344284399, 0.1841243791750922, 0.31850193596774} - }; - // py_resq3_r_cmpl = res_q3[1] - xarray py_resq3_r_cmpl = { - {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, - {0., 1.2223138676385652, 0.7215655008695085}, - {0., 0., 0.7854784971183756} - }; + // py_resq3_q_cmpl = res_q3[0] + xarray py_resq3_q_cmpl = { + {-0.2847566964660388, 0.6455031901264903, -0.0295327810119745}, + {-0.4551502060605353, -0.0838170448559192, -0.3133472182914375}, + {-0.0441600156766425, 0.6881200538051697, 0.0760152664601147}, + {-0.538335943107778, -0.2332659103773061, 0.7525061466150681}, + {-0.6328924578795164, -0.1203177215897514, -0.4769214096589271}, + {-0.1394394344284399, 0.1841243791750922, 0.3185019359677401} + }; + // py_resq3_r_cmpl = res_q3[1] + xarray py_resq3_r_cmpl = { + {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, + {0., 1.2223138676385652, 0.7215655008695085}, + {0., 0., 0.7854784971183754} + }; - auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); - EXPECT_TRUE(allclose(std::get<0>(res3), py_resq3_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res3), py_resq3_r_cmpl)); + auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); + CHECK(allclose(std::get<0>(res3), py_resq3_q_cmpl)); + CHECK(allclose(std::get<1>(res3), py_resq3_r_cmpl)); - // py_resq4_r_r = res_q4 - xarray py_resq4_r_r = { - {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, - {0., 1.2223138676385652, 0.7215655008695085}, - {0., 0., 0.7854784971183756} - }; + // py_resq4_r_r = res_q4 + xarray py_resq4_r_r = { + {-1.3152987216651169, -0.567877094797874, -1.0163710885529547}, + {0., 1.2223138676385652, 0.7215655008695085}, + {0., 0., 0.7854784971183754} + }; - auto res4 = linalg::qr(py_a, linalg::qrmode::r); - EXPECT_TRUE(allclose(std::get<1>(res4), py_resq4_r_r)); - } + auto res4 = linalg::qr(py_a, linalg::qrmode::r); + CHECK(allclose(std::get<1>(res4), py_resq4_r_r)); + } - /*py - a = np.random.random((5, 10)) - res_q1 = np.linalg.qr(a, 'raw') - res_q2 = np.linalg.qr(a, 'complete') - res_q3 = np.linalg.qr(a, 'reduced') - res_q4 = np.linalg.qr(a, 'r') - */ - TEST(xtest_extended, qr2) - { - // py_a - xarray py_a = { - {0.4319450186421158, - 0.2912291401980419, - 0.6118528947223795, - 0.1394938606520418, - 0.2921446485352182, - 0.3663618432936917, - 0.4560699842170359, - 0.7851759613930136, - 0.1996737821583597, - 0.5142344384136116}, - {0.5924145688620425, - 0.0464504127199977, - 0.6075448519014384, - 0.1705241236872915, - 0.0650515929852795, - 0.9488855372533332, - 0.9656320330745594, - 0.8083973481164611, - 0.3046137691733707, - 0.0976721140063839}, - {0.6842330265121569, - 0.4401524937396013, - 0.1220382348447788, - 0.4951769101112702, - 0.0343885211152184, - 0.9093204020787821, - 0.2587799816000169, - 0.662522284353982, - 0.311711076089411, - 0.5200680211778108}, - {0.5467102793432796, - 0.184854455525527, - 0.9695846277645586, - 0.7751328233611146, - 0.9394989415641891, - 0.8948273504276488, - 0.5978999788110851, - 0.9218742350231168, - 0.0884925020519195, - 0.1959828624191452}, - {0.0452272889105381, - 0.3253303307632643, - 0.388677289689482, - 0.2713490317738959, - 0.8287375091519293, - 0.3567533266935893, - 0.2809345096873808, - 0.5426960831582485, - 0.1409242249747626, - 0.8021969807540397} - }; - // py_resq1_h = res_q1[0] - xarray py_resq1_h = { - {-1.1430852952870696, 0.3761289948662397, 0.4344253062693247, 0.3471109568548026, 0.0287151863113738}, - {-0.4988738747365853, 0.4145384440977922, -0.1456730968857621, 0.1343802288038163, -0.4549175132696516 - }, - {-1.0982282164248067, 0.0432498341745755, 0.8009723247566577, -0.2697221220857602, -0.2118640849148783 - }, - {-0.8189559577243967, 0.2159221672678357, 0.2467828455102148, -0.4358731022610104, 0.0126894274012747 - }, - {-0.6468222288756241, 0.5399745339753013, 0.9011434603476536, -0.3516828694145329, 0.1205612964483228 - }, - {-1.6166030169206462, 0.0627336303098124, 0.1745159258713335, -0.1676233275811677, 0.3369911999240203 - }, - {-1.1247642047094615, -0.1631138338388988, 0.4469666475320985, 0.229673631977487, 0.3155802843315489}, - {-1.5746170823854417, 0.2876936477590398, 0.5186696050660637, 0.0972324032495854, 0.1124970816045023}, - {-0.4678059691956431, 0.0924634343088705, -0.0398310260167535, 0.1199094213119632, 0.1189824829973467 - }, - {-0.6817147826175952, 0.8209704648352938, 0.1936105292921998, 0.1556371881989978, 0.1610633542281174} - }; - // py_resq1_tau = res_q1[1] - xarray py_resq1_tau = - {1.3778764545594464, 1.6048419481909388, 1.7894907284949315, 1.9996780087119976, 0.}; + /*py + a = np.random.random((5, 10)) + res_q1 = np.linalg.qr(a, 'raw') + res_q2 = np.linalg.qr(a, 'complete') + res_q3 = np.linalg.qr(a, 'reduced') + res_q4 = np.linalg.qr(a, 'r') + */ + TEST_CASE("qr2") + { + // py_a + xarray py_a = { + {0.4319450186421158, + 0.2912291401980419, + 0.6118528947223795, + 0.1394938606520418, + 0.2921446485352182, + 0.3663618432936917, + 0.4560699842170359, + 0.7851759613930136, + 0.1996737821583597, + 0.5142344384136116}, + {0.5924145688620425, + 0.0464504127199977, + 0.6075448519014384, + 0.1705241236872915, + 0.0650515929852795, + 0.9488855372533332, + 0.9656320330745594, + 0.8083973481164611, + 0.3046137691733707, + 0.0976721140063839}, + {0.6842330265121569, + 0.4401524937396013, + 0.1220382348447788, + 0.4951769101112702, + 0.0343885211152184, + 0.9093204020787821, + 0.2587799816000169, + 0.662522284353982, + 0.311711076089411, + 0.5200680211778108}, + {0.5467102793432796, + 0.184854455525527, + 0.9695846277645586, + 0.7751328233611146, + 0.9394989415641891, + 0.8948273504276488, + 0.5978999788110851, + 0.9218742350231168, + 0.0884925020519195, + 0.1959828624191452}, + {0.0452272889105381, + 0.3253303307632643, + 0.388677289689482, + 0.2713490317738959, + 0.8287375091519293, + 0.3567533266935893, + 0.2809345096873808, + 0.5426960831582485, + 0.1409242249747626, + 0.8021969807540397} + }; + // py_resq1_h = res_q1[0] + xarray py_resq1_h = { + {-1.1430852952870696, 0.3761289948662397, 0.4344253062693247, 0.3471109568548026, 0.0287151863113738 + }, + {-0.4988738747365855, 0.4145384440977923, -0.1456730968857619, 0.1343802288038164, -0.4549175132696515 + }, + {-1.0982282164248067, 0.0432498341745755, 0.8009723247566577, -0.2697221220857602, -0.2118640849148782 + }, + {-0.8189559577243967, 0.2159221672678355, 0.2467828455102149, -0.4358731022610104, 0.0126894274012749 + }, + {-0.6468222288756241, 0.5399745339753012, 0.9011434603476536, -0.351682869414533, 0.120561296448323 + }, + {-1.6166030169206462, 0.0627336303098122, 0.1745159258713337, -0.1676233275811678, 0.3369911999240203 + }, + {-1.1247642047094615, -0.1631138338388989, 0.4469666475320985, 0.2296736319774871, 0.3155802843315489 + }, + {-1.5746170823854422, 0.2876936477590399, 0.5186696050660639, 0.0972324032495857, 0.1124970816045022 + }, + {-0.4678059691956431, 0.0924634343088704, -0.0398310260167535, 0.1199094213119632, 0.1189824829973467 + }, + {-0.6817147826175952, 0.820970464835294, 0.1936105292921997, 0.1556371881989975, 0.1610633542281176} + }; + // py_resq1_tau = res_q1[1] + xarray py_resq1_tau = + {1.3778764545594464, 1.604841948190939, 1.7894907284949315, 1.9996780087119976, 0.}; - auto res1 = linalg::qr(py_a, linalg::qrmode::raw); - EXPECT_TRUE(allclose(std::get<0>(res1), py_resq1_h)); - EXPECT_TRUE(allclose(std::get<1>(res1), py_resq1_tau)); - // py_resq2_q_cmpl = res_q2[0] - xarray py_resq2_q_cmpl = { - {-0.3778764545594464, 0.2477850983490842, 0.2323946032026168, 0.6442783657634198, -0.5715855718413764 - }, - {-0.5182592859033026, -0.5116427882060655, 0.0755411199296714, 0.3718390470559857, 0.5706647283090043 - }, - {-0.5985844007732788, 0.3414264138444295, -0.6868036045376356, -0.2311050279755064, 0.0039992397751236 - }, - {-0.4782760145698324, -0.1296501056095487, 0.5617369601749853, -0.6259003290333343, -0.217125009365484 - }, - {-0.0395659791067296, 0.737185903526121, 0.3912015899373974, 0.0384753772446496, 0.5481536630508244} - }; - // py_resq2_r_cmpl = res_q2[1] - xarray py_resq2_r_cmpl = { - {-1.1430852952870696, - -0.4988738747365853, - -1.0982282164248067, - -0.8189559577243967, - -0.6468222288756241, - -1.6166030169206462, - -1.1247642047094615, - -1.5746170823854417, - -0.4678059691956431, - -0.6817147826175952}, - {0., - 0.4145384440977922, - 0.0432498341745755, - 0.2159221672678357, - 0.5399745339753013, - 0.0627336303098124, - -0.1631138338388988, - 0.2876936477590398, - 0.0924634343088705, - 0.8209704648352938}, - {0., - 0., - 0.8009723247566577, - 0.2467828455102148, - 0.9011434603476536, - 0.1745159258713335, - 0.4469666475320985, - 0.5186696050660637, - -0.0398310260167535, - 0.1936105292921998}, - {0., - 0., - 0., - -0.4358731022610104, - -0.3516828694145329, - -0.1676233275811677, - 0.229673631977487, - 0.0972324032495854, - 0.1199094213119632, - 0.1556371881989978}, - {0., - 0., - 0., - 0., - 0.1205612964483228, - 0.3369911999240203, - 0.3155802843315489, - 0.1124970816045023, - 0.1189824829973467, - 0.1610633542281174} - }; + auto res1 = linalg::qr(py_a, linalg::qrmode::raw); + CHECK(allclose(std::get<0>(res1), py_resq1_h)); + CHECK(allclose(std::get<1>(res1), py_resq1_tau)); + // py_resq2_q_cmpl = res_q2[0] + xarray py_resq2_q_cmpl = { + {-0.3778764545594464, 0.2477850983490846, 0.2323946032026168, 0.6442783657634201, -0.571585571841376 + }, + {-0.5182592859033026, -0.5116427882060656, 0.0755411199296714, 0.3718390470559858, 0.570664728309004 + }, + {-0.5985844007732788, 0.3414264138444293, -0.6868036045376356, -0.2311050279755065, 0.0039992397751237 + }, + {-0.4782760145698324, -0.1296501056095487, 0.5617369601749854, -0.625900329033334, -0.2171250093654842 + }, + {-0.0395659791067296, 0.7371859035261209, 0.3912015899373973, 0.0384753772446493, 0.5481536630508248} + }; + // py_resq2_r_cmpl = res_q2[1] + xarray py_resq2_r_cmpl = { + {-1.1430852952870696, + -0.4988738747365855, + -1.0982282164248067, + -0.8189559577243967, + -0.6468222288756241, + -1.6166030169206462, + -1.1247642047094615, + -1.5746170823854422, + -0.4678059691956431, + -0.6817147826175952}, + {0., + 0.4145384440977923, + 0.0432498341745755, + 0.2159221672678355, + 0.5399745339753012, + 0.0627336303098122, + -0.1631138338388989, + 0.2876936477590399, + 0.0924634343088704, + 0.820970464835294}, + {0., + 0., + 0.8009723247566577, + 0.2467828455102149, + 0.9011434603476536, + 0.1745159258713337, + 0.4469666475320985, + 0.5186696050660639, + -0.0398310260167535, + 0.1936105292921997}, + {0., + 0., + 0., + -0.4358731022610104, + -0.351682869414533, + -0.1676233275811678, + 0.2296736319774871, + 0.0972324032495857, + 0.1199094213119632, + 0.1556371881989975}, + {0., + 0., + 0., + 0., + 0.120561296448323, + 0.3369911999240203, + 0.3155802843315489, + 0.1124970816045022, + 0.1189824829973467, + 0.1610633542281176} + }; - auto res2 = linalg::qr(py_a, linalg::qrmode::complete); - EXPECT_TRUE(allclose(std::get<0>(res2), py_resq2_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res2), py_resq2_r_cmpl)); + auto res2 = linalg::qr(py_a, linalg::qrmode::complete); + CHECK(allclose(std::get<0>(res2), py_resq2_q_cmpl)); + CHECK(allclose(std::get<1>(res2), py_resq2_r_cmpl)); - // py_resq3_q_cmpl = res_q3[0] - xarray py_resq3_q_cmpl = { - {-0.3778764545594464, 0.2477850983490842, 0.2323946032026168, 0.6442783657634198, -0.5715855718413764 - }, - {-0.5182592859033026, -0.5116427882060655, 0.0755411199296714, 0.3718390470559857, 0.5706647283090043 - }, - {-0.5985844007732788, 0.3414264138444295, -0.6868036045376356, -0.2311050279755064, 0.0039992397751236 - }, - {-0.4782760145698324, -0.1296501056095487, 0.5617369601749853, -0.6259003290333343, -0.217125009365484 - }, - {-0.0395659791067296, 0.737185903526121, 0.3912015899373974, 0.0384753772446496, 0.5481536630508244} - }; - // py_resq3_r_cmpl = res_q3[1] - xarray py_resq3_r_cmpl = { - {-1.1430852952870696, - -0.4988738747365853, - -1.0982282164248067, - -0.8189559577243967, - -0.6468222288756241, - -1.6166030169206462, - -1.1247642047094615, - -1.5746170823854417, - -0.4678059691956431, - -0.6817147826175952}, - {0., - 0.4145384440977922, - 0.0432498341745755, - 0.2159221672678357, - 0.5399745339753013, - 0.0627336303098124, - -0.1631138338388988, - 0.2876936477590398, - 0.0924634343088705, - 0.8209704648352938}, - {0., - 0., - 0.8009723247566577, - 0.2467828455102148, - 0.9011434603476536, - 0.1745159258713335, - 0.4469666475320985, - 0.5186696050660637, - -0.0398310260167535, - 0.1936105292921998}, - {0., - 0., - 0., - -0.4358731022610104, - -0.3516828694145329, - -0.1676233275811677, - 0.229673631977487, - 0.0972324032495854, - 0.1199094213119632, - 0.1556371881989978}, - {0., - 0., - 0., - 0., - 0.1205612964483228, - 0.3369911999240203, - 0.3155802843315489, - 0.1124970816045023, - 0.1189824829973467, - 0.1610633542281174} - }; + // py_resq3_q_cmpl = res_q3[0] + xarray py_resq3_q_cmpl = { + {-0.3778764545594464, 0.2477850983490846, 0.2323946032026168, 0.6442783657634201, -0.571585571841376 + }, + {-0.5182592859033026, -0.5116427882060656, 0.0755411199296714, 0.3718390470559858, 0.570664728309004 + }, + {-0.5985844007732788, 0.3414264138444293, -0.6868036045376356, -0.2311050279755065, 0.0039992397751237 + }, + {-0.4782760145698324, -0.1296501056095487, 0.5617369601749854, -0.625900329033334, -0.2171250093654842 + }, + {-0.0395659791067296, 0.7371859035261209, 0.3912015899373973, 0.0384753772446493, 0.5481536630508248} + }; + // py_resq3_r_cmpl = res_q3[1] + xarray py_resq3_r_cmpl = { + {-1.1430852952870696, + -0.4988738747365855, + -1.0982282164248067, + -0.8189559577243967, + -0.6468222288756241, + -1.6166030169206462, + -1.1247642047094615, + -1.5746170823854422, + -0.4678059691956431, + -0.6817147826175952}, + {0., + 0.4145384440977923, + 0.0432498341745755, + 0.2159221672678355, + 0.5399745339753012, + 0.0627336303098122, + -0.1631138338388989, + 0.2876936477590399, + 0.0924634343088704, + 0.820970464835294}, + {0., + 0., + 0.8009723247566577, + 0.2467828455102149, + 0.9011434603476536, + 0.1745159258713337, + 0.4469666475320985, + 0.5186696050660639, + -0.0398310260167535, + 0.1936105292921997}, + {0., + 0., + 0., + -0.4358731022610104, + -0.351682869414533, + -0.1676233275811678, + 0.2296736319774871, + 0.0972324032495857, + 0.1199094213119632, + 0.1556371881989975}, + {0., + 0., + 0., + 0., + 0.120561296448323, + 0.3369911999240203, + 0.3155802843315489, + 0.1124970816045022, + 0.1189824829973467, + 0.1610633542281176} + }; - auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); - EXPECT_TRUE(allclose(std::get<0>(res3), py_resq3_q_cmpl)); - EXPECT_TRUE(allclose(std::get<1>(res3), py_resq3_r_cmpl)); + auto res3 = linalg::qr(py_a, linalg::qrmode::reduced); + CHECK(allclose(std::get<0>(res3), py_resq3_q_cmpl)); + CHECK(allclose(std::get<1>(res3), py_resq3_r_cmpl)); - // py_resq4_r_r = res_q4 - xarray py_resq4_r_r = { - {-1.1430852952870696, - -0.4988738747365853, - -1.0982282164248067, - -0.8189559577243967, - -0.6468222288756241, - -1.6166030169206462, - -1.1247642047094615, - -1.5746170823854417, - -0.4678059691956431, - -0.6817147826175952}, - {0., - 0.4145384440977922, - 0.0432498341745755, - 0.2159221672678357, - 0.5399745339753013, - 0.0627336303098124, - -0.1631138338388988, - 0.2876936477590398, - 0.0924634343088705, - 0.8209704648352938}, - {0., - 0., - 0.8009723247566577, - 0.2467828455102148, - 0.9011434603476536, - 0.1745159258713335, - 0.4469666475320985, - 0.5186696050660637, - -0.0398310260167535, - 0.1936105292921998}, - {0., - 0., - 0., - -0.4358731022610104, - -0.3516828694145329, - -0.1676233275811677, - 0.229673631977487, - 0.0972324032495854, - 0.1199094213119632, - 0.1556371881989978}, - {0., - 0., - 0., - 0., - 0.1205612964483228, - 0.3369911999240203, - 0.3155802843315489, - 0.1124970816045023, - 0.1189824829973467, - 0.1610633542281174} - }; + // py_resq4_r_r = res_q4 + xarray py_resq4_r_r = { + {-1.1430852952870696, + -0.4988738747365855, + -1.0982282164248067, + -0.8189559577243967, + -0.6468222288756241, + -1.6166030169206462, + -1.1247642047094615, + -1.5746170823854422, + -0.4678059691956431, + -0.6817147826175952}, + {0., + 0.4145384440977923, + 0.0432498341745755, + 0.2159221672678355, + 0.5399745339753012, + 0.0627336303098122, + -0.1631138338388989, + 0.2876936477590399, + 0.0924634343088704, + 0.820970464835294}, + {0., + 0., + 0.8009723247566577, + 0.2467828455102149, + 0.9011434603476536, + 0.1745159258713337, + 0.4469666475320985, + 0.5186696050660639, + -0.0398310260167535, + 0.1936105292921997}, + {0., + 0., + 0., + -0.4358731022610104, + -0.351682869414533, + -0.1676233275811678, + 0.2296736319774871, + 0.0972324032495857, + 0.1199094213119632, + 0.1556371881989975}, + {0., + 0., + 0., + 0., + 0.120561296448323, + 0.3369911999240203, + 0.3155802843315489, + 0.1124970816045022, + 0.1189824829973467, + 0.1610633542281176} + }; - auto res4 = linalg::qr(py_a, linalg::qrmode::r); - EXPECT_TRUE(allclose(std::get<1>(res4), py_resq4_r_r)); + auto res4 = linalg::qr(py_a, linalg::qrmode::r); + CHECK(allclose(std::get<1>(res4), py_resq4_r_r)); + } } -} // namespace xt +} diff --git a/test/test_tensordot.cpp b/test/test_tensordot.cpp index 74a99c6..b367d82 100644 --- a/test/test_tensordot.cpp +++ b/test/test_tensordot.cpp @@ -12,196 +12,202 @@ #include "xtensor/views/xstrided_view.hpp" #include "xtensor/views/xview.hpp" -#include "gtest/gtest.h" +#include "doctest/doctest.h" #include "xtensor-blas/xlinalg.hpp" namespace xt { - TEST(xtensordot, outer_product) + TEST_SUITE("xtensor_dot") { - xarray a = xt::ones({3, 3, 3}); - xarray b = xt::ones({2, 2}) * 5.0; - xarray e1 = xt::ones({3, 3, 3, 2, 2}) * 5.0; - - auto r1 = linalg::tensordot(a, b, 0); - EXPECT_EQ(e1, r1); - } - - TEST(xtensordot, outer_product_cm) - { - xarray a = xt::ones({3, 3, 3}); - xarray b = xt::ones({2, 2}) * 5.0; - xarray e1 = xt::ones({3, 3, 3, 2, 2}) * 5.0; - - auto r1 = linalg::tensordot(a, b, 0); - EXPECT_EQ(e1, r1); - } - - TEST(xtensordot, outer_product_mixed_layout) - { - xarray a = xt::ones({3, 3, 3}); - xarray b = xt::ones({2, 2}) * 5.0; - xarray e1 = xt::ones({3, 3, 3, 2, 2}) * 5.0; - - auto r1 = linalg::tensordot(a, b, 0); - EXPECT_EQ(e1, r1); - - xarray e2 = xt::ones({2, 2, 3, 3, 3}) * 5.0; - auto r2 = linalg::tensordot(b, a, 0); - EXPECT_EQ(e2, r2); - } - - TEST(xtensordot, inner_product) - { - xarray a = xt::ones({3, 3, 2, 2}); - xarray b = xt::ones({2, 2, 10}); - auto r1 = linalg::tensordot(a, b); - EXPECT_TRUE(all(equal(r1, 4))); - EXPECT_TRUE(r1.shape().size() == 3); - EXPECT_TRUE(r1.shape()[0] == 3); - EXPECT_TRUE(r1.shape()[1] == 3); - EXPECT_TRUE(r1.shape()[2] == 10); - - EXPECT_THROW(linalg::tensordot(a, b, 3), std::runtime_error); - EXPECT_THROW(linalg::tensordot(b, a), std::runtime_error); - } - - TEST(xtensordot, inner_product_cm) - { - xarray a = xt::ones({3, 3, 2, 2}); - xarray b = xt::ones({2, 2, 10}); - auto r1 = linalg::tensordot(a, b); - EXPECT_TRUE(all(equal(r1, 4))); - EXPECT_TRUE(r1.shape().size() == 3); - EXPECT_TRUE(r1.shape()[0] == 3); - EXPECT_TRUE(r1.shape()[1] == 3); - EXPECT_TRUE(r1.shape()[2] == 10); - - EXPECT_THROW(linalg::tensordot(a, b, 3), std::runtime_error); - EXPECT_THROW(linalg::tensordot(b, a), std::runtime_error); - } - - TEST(xtensordot, inner_product_mixed_layout) - { - xarray a = xt::ones({3, 3, 2, 2}); - xarray b = xt::ones({3, 2, 2, 10}); - auto r1 = linalg::tensordot(a, b, 3); - EXPECT_TRUE(all(equal(r1, 12.0))); - EXPECT_TRUE(r1.shape().size() == 2); - EXPECT_TRUE(r1.shape()[0] == 3); - EXPECT_TRUE(r1.shape()[1] == 10); - - EXPECT_THROW(linalg::tensordot(b, a), std::runtime_error); - } - - TEST(xtensordot, tuple_ax) - { - xarray a = { - {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, - {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, - {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} - }; - xarray b = xt::ones({2, 3, 2, 3}); - auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); - xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; - EXPECT_EQ(r1, e1); - auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3}); - xarray e2 = xarray::from_shape({1, 1}); - e2(0, 0) = 630; - EXPECT_EQ(r2(0, 0), e2(0, 0)); - } - - TEST(xtensordot, tuple_ax_cm) - { - xarray a = { - {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, - {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, - {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} - }; - xarray b = xt::ones({2, 3, 2, 3}); - auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); - xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; - EXPECT_EQ(r1, e1); - auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3}); - xarray e2 = xarray::from_shape({1, 1}); - e2(0, 0) = 630; - EXPECT_EQ(r2(0, 0), e2(0, 0)); - } - - TEST(xtensordot, tuple_ax_mixed_layout) - { - xarray a = { - {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, - {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, - {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} - }; - xarray b = xt::ones({2, 3, 2, 3}); - auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); - xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; - EXPECT_EQ(r1, e1); - - auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3}); - xarray e2 = {630}; - - EXPECT_EQ(r2, e2); - } - - TEST(xtensordot, view) - { - xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); - xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); - - xarray e1 = {{34, 90, 146}, {46, 134, 222}, {58, 178, 298}}; - - auto res1 = linalg::tensordot( - view(a, 0, all(), all(), all()), - view(b, 0, all(), all(), all()), - {0, 2}, - {1, 2} - ); - - EXPECT_EQ(res1, e1); - EXPECT_EQ(res1.dimension(), 2u); - EXPECT_EQ(res1.shape()[0], 3u); - EXPECT_EQ(res1.shape()[1], 3u); - } - - TEST(xtensordot, strided_view_range) - { - xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); - xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); - - xarray e1 = {{1064, 1144}, {1136, 1224}}; - - auto res1 = linalg::tensordot( - strided_view(a, {range(0, 2), all(), range(0, 2), all()}), - strided_view(b, {range(0, 2), range(0, 2), all(), all()}), - {0, 1, 2}, - {0, 1, 2} - ); - EXPECT_EQ(res1, e1); - EXPECT_EQ(res1.dimension(), 2u); - EXPECT_EQ(res1.shape()[0], 2u); - EXPECT_EQ(res1.shape()[1], 2u); - } - - TEST(xtensordot, reducing_dim_view) - { - xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); - xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); - - xarray e = {1589}; - auto r = linalg::tensordot(view(a, 0, 1, all(), all()), view(b, 2, all(), 1, all())); - EXPECT_EQ(r, e); - } - - TEST(xtensordot, reducing_dim_strided_view) - { - xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); - xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); - - xarray e = {1589}; - auto r = linalg::tensordot(strided_view(a, {0, 1, all(), all()}), strided_view(b, {2, all(), 1, all()})); - EXPECT_EQ(r, e); + TEST_CASE("outer_product") + { + xarray a = xt::ones({3, 3, 3}); + xarray b = xt::ones({2, 2}) * 5.0; + xarray e1 = xt::ones({3, 3, 3, 2, 2}) * 5.0; + + auto r1 = linalg::tensordot(a, b, 0); + CHECK_EQ(e1, r1); + } + + TEST_CASE("outer_product_cm") + { + xarray a = xt::ones({3, 3, 3}); + xarray b = xt::ones({2, 2}) * 5.0; + xarray e1 = xt::ones({3, 3, 3, 2, 2}) * 5.0; + + auto r1 = linalg::tensordot(a, b, 0); + CHECK_EQ(e1, r1); + } + + TEST_CASE("outer_product_mixed_layout") + { + xarray a = xt::ones({3, 3, 3}); + xarray b = xt::ones({2, 2}) * 5.0; + xarray e1 = xt::ones({3, 3, 3, 2, 2}) * 5.0; + + auto r1 = linalg::tensordot(a, b, 0); + CHECK_EQ(e1, r1); + + xarray e2 = xt::ones({2, 2, 3, 3, 3}) * 5.0; + auto r2 = linalg::tensordot(b, a, 0); + CHECK_EQ(e2, r2); + } + + TEST_CASE("inner_product") + { + xarray a = xt::ones({3, 3, 2, 2}); + xarray b = xt::ones({2, 2, 10}); + auto r1 = linalg::tensordot(a, b); + CHECK(all(equal(r1, 4))); + CHECK(r1.shape().size() == 3); + CHECK(r1.shape()[0] == 3); + CHECK(r1.shape()[1] == 3); + CHECK(r1.shape()[2] == 10); + + CHECK_THROWS_AS(linalg::tensordot(a, b, 3), std::runtime_error); + CHECK_THROWS_AS(linalg::tensordot(b, a), std::runtime_error); + } + + TEST_CASE("inner_product_cm") + { + xarray a = xt::ones({3, 3, 2, 2}); + xarray b = xt::ones({2, 2, 10}); + auto r1 = linalg::tensordot(a, b); + CHECK(all(equal(r1, 4))); + CHECK(r1.shape().size() == 3); + CHECK(r1.shape()[0] == 3); + CHECK(r1.shape()[1] == 3); + CHECK(r1.shape()[2] == 10); + + CHECK_THROWS_AS(linalg::tensordot(a, b, 3), std::runtime_error); + CHECK_THROWS_AS(linalg::tensordot(b, a), std::runtime_error); + } + + TEST_CASE("inner_product_mixed_layout") + { + xarray a = xt::ones({3, 3, 2, 2}); + xarray b = xt::ones({3, 2, 2, 10}); + auto r1 = linalg::tensordot(a, b, 3); + CHECK(all(equal(r1, 12.0))); + CHECK(r1.shape().size() == 2); + CHECK(r1.shape()[0] == 3); + CHECK(r1.shape()[1] == 10); + + CHECK_THROWS_AS(linalg::tensordot(b, a), std::runtime_error); + } + + TEST_CASE("tuple_ax") + { + xarray a = { + {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, + {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, + {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} + }; + xarray b = xt::ones({2, 3, 2, 3}); + auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); + xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; + CHECK_EQ(r1, e1); + auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3}); + xarray e2 = xarray::from_shape({1, 1}); + e2(0, 0) = 630; + CHECK_EQ(r2(0, 0), e2(0, 0)); + } + + TEST_CASE("tuple_ax_cm") + { + xarray a = { + {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, + {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, + {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} + }; + xarray b = xt::ones({2, 3, 2, 3}); + auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); + xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; + CHECK_EQ(r1, e1); + auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3}); + xarray e2 = xarray::from_shape({1, 1}); + e2(0, 0) = 630; + CHECK_EQ(r2(0, 0), e2(0, 0)); + } + + TEST_CASE("tuple_ax_mixed_layout") + { + xarray a = { + {{{0, 1}, {2, 3}, {4, 5}}, {{6, 7}, {8, 9}, {10, 11}}}, + {{{12, 13}, {14, 15}, {16, 17}}, {{18, 19}, {20, 21}, {22, 23}}}, + {{{24, 25}, {26, 27}, {28, 29}}, {{30, 31}, {32, 33}, {34, 35}}} + }; + xarray b = xt::ones({2, 3, 2, 3}); + auto r1 = linalg::tensordot(a, b, {1, 3, 2}, {0, 2, 1}); + xarray e1 = {{66, 66, 66}, {210, 210, 210}, {354, 354, 354}}; + CHECK_EQ(r1, e1); + + auto r2 = linalg::tensordot(a, b, {1, 3, 2, 0}, {0, 2, 1, 3}); + xarray e2 = {630}; + + CHECK_EQ(r2, e2); + } + + TEST_CASE("view") + { + xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); + xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); + + xarray e1 = {{34, 90, 146}, {46, 134, 222}, {58, 178, 298}}; + + auto res1 = linalg::tensordot( + view(a, 0, all(), all(), all()), + view(b, 0, all(), all(), all()), + {0, 2}, + {1, 2} + ); + + CHECK_EQ(res1, e1); + CHECK_EQ(res1.dimension(), 2u); + CHECK_EQ(res1.shape()[0], 3u); + CHECK_EQ(res1.shape()[1], 3u); + } + + TEST_CASE("strided_view_range") + { + xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); + xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); + + xarray e1 = {{1064, 1144}, {1136, 1224}}; + + auto res1 = linalg::tensordot( + strided_view(a, {range(0, 2), all(), range(0, 2), all()}), + strided_view(b, {range(0, 2), range(0, 2), all(), all()}), + {0, 1, 2}, + {0, 1, 2} + ); + CHECK_EQ(res1, e1); + CHECK_EQ(res1.dimension(), 2u); + CHECK_EQ(res1.shape()[0], 2u); + CHECK_EQ(res1.shape()[1], 2u); + } + + TEST_CASE("reducing_dim_view") + { + xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); + xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); + + xarray e = {1589}; + auto r = linalg::tensordot(view(a, 0, 1, all(), all()), view(b, 2, all(), 1, all())); + CHECK_EQ(r, e); + } + + TEST_CASE("reducing_dim_strided_view") + { + xarray a = reshape_view(arange(3 * 2 * 3 * 2), {3, 2, 3, 2}); + xarray b = reshape_view(arange(3 * 3 * 2 * 2), {3, 3, 2, 2}); + + xarray e = {1589}; + auto r = linalg::tensordot( + strided_view(a, {0, 1, all(), all()}), + strided_view(b, {2, all(), 1, all()}) + ); + CHECK_EQ(r, e); + } } } // namespace xt