Skip to content

Commit 3d8d669

Browse files
authored
Merge pull request #3217 from stan-dev/fix/3216-dot-product-checks
Fix rows/columns_dot_product not properly checking inputs
2 parents f7ccc01 + 28337f4 commit 3d8d669

File tree

10 files changed

+14
-12
lines changed

10 files changed

+14
-12
lines changed

stan/math/opencl/prim/columns_dot_product.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <stan/math/opencl/prim/sum.hpp>
66
#include <stan/math/prim/meta.hpp>
77
#include <stan/math/prim/err/check_vector.hpp>
8-
#include <stan/math/prim/err/check_matching_sizes.hpp>
8+
#include <stan/math/prim/err/check_matching_dims.hpp>
99

1010
namespace stan {
1111
namespace math {
@@ -26,7 +26,7 @@ template <typename T_a, typename T_b,
2626
require_all_kernel_expressions_and_none_scalar_t<T_a, T_b>* = nullptr>
2727
inline auto columns_dot_product(const T_a& a, const T_b& b) {
2828
using res_scal = std::common_type_t<value_type_t<T_a>, value_type_t<T_b>>;
29-
check_matching_sizes("columns_dot_product", "a", a, "b", b);
29+
check_matching_dims("columns_dot_product", "a", a, "b", b);
3030
matrix_cl<res_scal> res;
3131

3232
if (size_zero(a, b)) {

stan/math/opencl/prim/rows_dot_product.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
#include <stan/math/opencl/prim/sum.hpp>
66
#include <stan/math/prim/meta.hpp>
77
#include <stan/math/prim/err/check_vector.hpp>
8-
#include <stan/math/prim/err/check_matching_sizes.hpp>
8+
#include <stan/math/prim/err/check_matching_dims.hpp>
99

1010
namespace stan {
1111
namespace math {
@@ -25,7 +25,7 @@ namespace math {
2525
template <typename T_a, typename T_b,
2626
require_all_kernel_expressions_and_none_scalar_t<T_a, T_b>* = nullptr>
2727
inline auto rows_dot_product(T_a&& a, T_b&& b) {
28-
check_matching_sizes("rows_dot_product", "a", a, "b", b);
28+
check_matching_dims("rows_dot_product", "a", a, "b", b);
2929
return rowwise_sum(elt_multiply(std::forward<T_a>(a), std::forward<T_b>(b)));
3030
}
3131

stan/math/opencl/rev/columns_dot_product.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ template <
2828
typename T1, typename T2, require_any_var_t<T1, T2>* = nullptr,
2929
require_all_nonscalar_prim_or_rev_kernel_expression_t<T1, T2>* = nullptr>
3030
inline var_value<matrix_cl<double>> columns_dot_product(T1&& v1, T2&& v2) {
31-
check_matching_sizes("columns_dot_product(OpenCL)", "v1", v1, "v2", v2);
31+
check_matching_dims("columns_dot_product(OpenCL)", "v1", v1, "v2", v2);
3232

3333
if (size_zero(v1, v2)) {
3434
return var_value<matrix_cl<double>>(constant(0.0, 1, v1.cols()));

stan/math/opencl/rev/rows_dot_product.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ template <
2828
typename T1, typename T2, require_any_var_t<T1, T2>* = nullptr,
2929
require_all_nonscalar_prim_or_rev_kernel_expression_t<T1, T2>* = nullptr>
3030
inline var_value<matrix_cl<double>> rows_dot_product(T1&& v1, T2&& v2) {
31-
check_matching_sizes("rows_dot_product(OpenCL)", "v1", v1, "v2", v2);
31+
check_matching_dims("rows_dot_product(OpenCL)", "v1", v1, "v2", v2);
3232

3333
arena_t<T1> v1_arena = std::forward<T1>(v1);
3434
arena_t<T2> v2_arena = std::forward<T2>(v2);

stan/math/prim/fun/columns_dot_product.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ template <typename Mat1, typename Mat2,
2626
require_all_not_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
2727
inline Eigen::Matrix<return_type_t<Mat1, Mat2>, 1, Mat1::ColsAtCompileTime>
2828
columns_dot_product(const Mat1& v1, const Mat2& v2) {
29-
check_matching_sizes("columns_dot_product", "v1", v1, "v2", v2);
29+
check_matching_dims("columns_dot_product", "v1", v1, "v2", v2);
3030
return v1.cwiseProduct(v2).colwise().sum();
3131
}
3232

stan/math/prim/fun/rows_dot_product.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ template <typename Mat1, typename Mat2,
2626
require_all_not_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
2727
inline Eigen::Matrix<return_type_t<Mat1, Mat2>, Mat1::RowsAtCompileTime, 1>
2828
rows_dot_product(const Mat1& v1, const Mat2& v2) {
29-
check_matching_sizes("rows_dot_product", "v1", v1, "v2", v2);
29+
check_matching_dims("rows_dot_product", "v1", v1, "v2", v2);
3030
return (v1.cwiseProduct(v2)).rowwise().sum();
3131
}
3232

stan/math/rev/fun/columns_dot_product.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ template <typename Mat1, typename Mat2,
3333
require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
3434
inline Eigen::Matrix<return_type_t<Mat1, Mat2>, 1, Mat1::ColsAtCompileTime>
3535
columns_dot_product(const Mat1& v1, const Mat2& v2) {
36-
check_matching_sizes("dot_product", "v1", v1, "v2", v2);
36+
check_matching_dims("check_matching_dims", "v1", v1, "v2", v2);
3737
Eigen::Matrix<var, 1, Mat1::ColsAtCompileTime> ret(1, v1.cols());
3838
for (size_type j = 0; j < v1.cols(); ++j) {
3939
ret.coeffRef(j) = dot_product(v1.col(j), v2.col(j));
@@ -62,7 +62,7 @@ template <typename Mat1, typename Mat2,
6262
require_all_matrix_t<Mat1, Mat2>* = nullptr,
6363
require_any_var_matrix_t<Mat1, Mat2>* = nullptr>
6464
inline auto columns_dot_product(const Mat1& v1, const Mat2& v2) {
65-
check_matching_sizes("columns_dot_product", "v1", v1, "v2", v2);
65+
check_matching_dims("columns_dot_product", "v1", v1, "v2", v2);
6666
using inner_return_t = decltype(
6767
(value_of(v1).array() * value_of(v2).array()).colwise().sum().matrix());
6868
using return_t = return_var_matrix_t<inner_return_t, Mat1, Mat2>;

stan/math/rev/fun/rows_dot_product.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ template <typename Mat1, typename Mat2,
3232
require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr>
3333
inline Eigen::Matrix<var, Mat1::RowsAtCompileTime, 1> rows_dot_product(
3434
const Mat1& v1, const Mat2& v2) {
35-
check_matching_sizes("dot_product", "v1", v1, "v2", v2);
35+
check_matching_dims("rows_dot_product", "v1", v1, "v2", v2);
3636
Eigen::Matrix<var, Mat1::RowsAtCompileTime, 1> ret(v1.rows(), 1);
3737
for (size_type j = 0; j < v1.rows(); ++j) {
3838
ret.coeffRef(j) = dot_product(v1.row(j), v2.row(j));
@@ -61,7 +61,7 @@ template <typename Mat1, typename Mat2,
6161
require_all_matrix_t<Mat1, Mat2>* = nullptr,
6262
require_any_var_matrix_t<Mat1, Mat2>* = nullptr>
6363
inline auto rows_dot_product(const Mat1& v1, const Mat2& v2) {
64-
check_matching_sizes("rows_dot_product", "v1", v1, "v2", v2);
64+
check_matching_dims("rows_dot_product", "v1", v1, "v2", v2);
6565

6666
using return_t = return_var_matrix_t<
6767
decltype((v1.val().array() * v2.val().array()).rowwise().sum().matrix()),

test/unit/math/mix/fun/columns_dot_product_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ TEST(MathMixMatFun, columnsDotProduct) {
6161
stan::test::expect_ad(f, erv2, erv3);
6262
stan::test::expect_ad(f, em33, em23);
6363
stan::test::expect_ad(f, em23, em33);
64+
stan::test::expect_ad(f, em23, em32);
6465
stan::test::expect_ad_matvar(f, ev2, ev3);
6566
stan::test::expect_ad_matvar(f, erv2, erv3);
6667
stan::test::expect_ad_matvar(f, em33, em23);

test/unit/math/mix/fun/rows_dot_product_test.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ TEST(MathMixMatFun, rowsDotProduct) {
6161
stan::test::expect_ad(f, erv2, erv3);
6262
stan::test::expect_ad(f, em33, em23);
6363
stan::test::expect_ad(f, em23, em33);
64+
stan::test::expect_ad(f, em23, em32);
6465
stan::test::expect_ad_matvar(f, ev2, ev3);
6566
stan::test::expect_ad_matvar(f, erv2, erv3);
6667
stan::test::expect_ad_matvar(f, em33, em23);

0 commit comments

Comments
 (0)