Skip to content

Commit 4a0c0e8

Browse files
authored
Merge pull request #3298 from stan-dev/feature/row-col-stochastic
row/col stochastic matrix for deserializer and serializer
2 parents 7e76f26 + 70e9ec3 commit 4a0c0e8

File tree

2 files changed

+518
-529
lines changed

2 files changed

+518
-529
lines changed

src/stan/io/deserializer.hpp

Lines changed: 208 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -382,11 +382,8 @@ class deserializer {
382382
template <typename Ret, bool Jacobian, typename LB, typename LP,
383383
typename... Sizes>
384384
inline auto read_constrain_lb(const LB& lb, LP& lp, Sizes... sizes) {
385-
if (Jacobian) {
386-
return stan::math::lb_constrain(this->read<Ret>(sizes...), lb, lp);
387-
} else {
388-
return stan::math::lb_constrain(this->read<Ret>(sizes...), lb);
389-
}
385+
return stan::math::lb_constrain<Jacobian>(this->read<Ret>(sizes...), lb,
386+
lp);
390387
}
391388

392389
/**
@@ -408,11 +405,8 @@ class deserializer {
408405
template <typename Ret, bool Jacobian, typename UB, typename LP,
409406
typename... Sizes>
410407
inline auto read_constrain_ub(const UB& ub, LP& lp, Sizes... sizes) {
411-
if (Jacobian) {
412-
return stan::math::ub_constrain(this->read<Ret>(sizes...), ub, lp);
413-
} else {
414-
return stan::math::ub_constrain(this->read<Ret>(sizes...), ub);
415-
}
408+
return stan::math::ub_constrain<Jacobian>(this->read<Ret>(sizes...), ub,
409+
lp);
416410
}
417411

418412
/**
@@ -437,11 +431,8 @@ class deserializer {
437431
typename... Sizes>
438432
inline auto read_constrain_lub(const LB& lb, const UB& ub, LP& lp,
439433
Sizes... sizes) {
440-
if (Jacobian) {
441-
return stan::math::lub_constrain(this->read<Ret>(sizes...), lb, ub, lp);
442-
} else {
443-
return stan::math::lub_constrain(this->read<Ret>(sizes...), lb, ub);
444-
}
434+
return stan::math::lub_constrain<Jacobian>(this->read<Ret>(sizes...), lb,
435+
ub, lp);
445436
}
446437

447438
/**
@@ -470,14 +461,8 @@ class deserializer {
470461
inline auto read_constrain_offset_multiplier(const Offset& offset,
471462
const Mult& multiplier, LP& lp,
472463
Sizes... sizes) {
473-
using stan::math::offset_multiplier_constrain;
474-
if (Jacobian) {
475-
return offset_multiplier_constrain(this->read<Ret>(sizes...), offset,
476-
multiplier, lp);
477-
} else {
478-
return offset_multiplier_constrain(this->read<Ret>(sizes...), offset,
479-
multiplier);
480-
}
464+
return stan::math::offset_multiplier_constrain<Jacobian>(
465+
this->read<Ret>(sizes...), offset, multiplier, lp);
481466
}
482467

483468
/**
@@ -501,12 +486,8 @@ class deserializer {
501486
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
502487
require_not_std_vector_t<Ret>* = nullptr>
503488
inline auto read_constrain_unit_vector(LP& lp, Sizes... sizes) {
504-
using stan::math::unit_vector_constrain;
505-
if (Jacobian) {
506-
return math::eval(unit_vector_constrain(this->read<Ret>(sizes...), lp));
507-
} else {
508-
return math::eval(unit_vector_constrain(this->read<Ret>(sizes...)));
509-
}
489+
return stan::math::eval(stan::math::unit_vector_constrain<Jacobian>(
490+
this->read<Ret>(sizes...), lp));
510491
}
511492

512493
/**
@@ -562,13 +543,9 @@ class deserializer {
562543
template <typename Ret, bool Jacobian, typename LP,
563544
require_not_std_vector_t<Ret>* = nullptr>
564545
inline auto read_constrain_simplex(LP& lp, size_t size) {
565-
using stan::math::simplex_constrain;
566546
stan::math::check_positive("read_simplex", "size", size);
567-
if (Jacobian) {
568-
return simplex_constrain(this->read<Ret>(size - 1), lp);
569-
} else {
570-
return simplex_constrain(this->read<Ret>(size - 1));
571-
}
547+
return stan::math::simplex_constrain<Jacobian>(this->read<Ret>(size - 1),
548+
lp);
572549
}
573550

574551
/**
@@ -624,12 +601,8 @@ class deserializer {
624601
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
625602
require_not_std_vector_t<Ret>* = nullptr>
626603
inline auto read_constrain_ordered(LP& lp, Sizes... sizes) {
627-
using stan::math::ordered_constrain;
628-
if (Jacobian) {
629-
return ordered_constrain(this->read<Ret>(sizes...), lp);
630-
} else {
631-
return ordered_constrain(this->read<Ret>(sizes...));
632-
}
604+
return stan::math::ordered_constrain<Jacobian>(this->read<Ret>(sizes...),
605+
lp);
633606
}
634607

635608
/**
@@ -684,12 +657,8 @@ class deserializer {
684657
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
685658
require_not_std_vector_t<Ret>* = nullptr>
686659
inline auto read_constrain_positive_ordered(LP& lp, Sizes... sizes) {
687-
using stan::math::positive_ordered_constrain;
688-
if (Jacobian) {
689-
return positive_ordered_constrain(this->read<Ret>(sizes...), lp);
690-
} else {
691-
return positive_ordered_constrain(this->read<Ret>(sizes...));
692-
}
660+
return stan::math::positive_ordered_constrain<Jacobian>(
661+
this->read<Ret>(sizes...), lp);
693662
}
694663

695664
/**
@@ -745,17 +714,10 @@ class deserializer {
745714
require_matrix_t<Ret>* = nullptr>
746715
inline auto read_constrain_cholesky_factor_cov(LP& lp, Eigen::Index M,
747716
Eigen::Index N) {
748-
if (Jacobian) {
749-
return stan::math::cholesky_factor_constrain(
750-
this->read<conditional_var_val_t<Ret, vector_t>>((N * (N + 1)) / 2
751-
+ (M - N) * N),
752-
M, N, lp);
753-
} else {
754-
return stan::math::cholesky_factor_constrain(
755-
this->read<conditional_var_val_t<Ret, vector_t>>((N * (N + 1)) / 2
756-
+ (M - N) * N),
757-
M, N);
758-
}
717+
return stan::math::cholesky_factor_constrain<Jacobian>(
718+
this->read<conditional_var_val_t<Ret, vector_t>>((N * (N + 1)) / 2
719+
+ (M - N) * N),
720+
M, N, lp);
759721
}
760722

761723
/**
@@ -811,16 +773,9 @@ class deserializer {
811773
template <typename Ret, bool Jacobian, typename LP,
812774
require_matrix_t<Ret>* = nullptr>
813775
inline auto read_constrain_cholesky_factor_corr(LP& lp, Eigen::Index K) {
814-
using stan::math::cholesky_corr_constrain;
815-
if (Jacobian) {
816-
return cholesky_corr_constrain(
817-
this->read<conditional_var_val_t<Ret, vector_t>>((K * (K - 1)) / 2),
818-
K, lp);
819-
} else {
820-
return cholesky_corr_constrain(
821-
this->read<conditional_var_val_t<Ret, vector_t>>((K * (K - 1)) / 2),
822-
K);
823-
}
776+
return stan::math::cholesky_corr_constrain<Jacobian>(
777+
this->read<conditional_var_val_t<Ret, vector_t>>((K * (K - 1)) / 2), K,
778+
lp);
824779
}
825780

826781
/**
@@ -875,18 +830,9 @@ class deserializer {
875830
template <typename Ret, bool Jacobian, typename LP,
876831
require_matrix_t<Ret>* = nullptr>
877832
inline auto read_constrain_cov_matrix(LP& lp, Eigen::Index k) {
878-
using stan::math::cov_matrix_constrain;
879-
if (Jacobian) {
880-
return cov_matrix_constrain(
881-
this->read<conditional_var_val_t<Ret, vector_t>>(k
882-
+ (k * (k - 1)) / 2),
883-
k, lp);
884-
} else {
885-
return cov_matrix_constrain(
886-
this->read<conditional_var_val_t<Ret, vector_t>>(k
887-
+ (k * (k - 1)) / 2),
888-
k);
889-
}
833+
return stan::math::cov_matrix_constrain<Jacobian>(
834+
this->read<conditional_var_val_t<Ret, vector_t>>(k + (k * (k - 1)) / 2),
835+
k, lp);
890836
}
891837

892838
/**
@@ -939,16 +885,9 @@ class deserializer {
939885
require_not_std_vector_t<Ret>* = nullptr,
940886
require_matrix_t<Ret>* = nullptr>
941887
inline auto read_constrain_corr_matrix(LP& lp, Eigen::Index k) {
942-
using stan::math::corr_matrix_constrain;
943-
if (Jacobian) {
944-
return corr_matrix_constrain(
945-
this->read<conditional_var_val_t<Ret, vector_t>>((k * (k - 1)) / 2),
946-
k, lp);
947-
} else {
948-
return corr_matrix_constrain(
949-
this->read<conditional_var_val_t<Ret, vector_t>>((k * (k - 1)) / 2),
950-
k);
951-
}
888+
return stan::math::corr_matrix_constrain<Jacobian>(
889+
this->read<conditional_var_val_t<Ret, vector_t>>((k * (k - 1)) / 2), k,
890+
lp);
952891
}
953892

954893
/**
@@ -981,6 +920,118 @@ class deserializer {
981920
return ret;
982921
}
983922

923+
/**
924+
* Return the next object transformed to a matrix with simplexes along the
925+
* columns
926+
*
927+
* <p>See <code>stan::math::stochastic_column_constrain(T,T&)</code>.
928+
*
929+
* @tparam Ret The type to return.
930+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
931+
* determinant of the transform.
932+
* @tparam LP Type of log probability.
933+
* @param lp The reference to the variable holding the log
934+
* probability to increment.
935+
* @param rows Rows of matrix
936+
* @param cows Cows of matrix
937+
*/
938+
template <typename Ret, bool Jacobian, typename LP,
939+
require_not_std_vector_t<Ret>* = nullptr,
940+
require_matrix_t<Ret>* = nullptr>
941+
inline auto read_constrain_stochastic_column(LP& lp, Eigen::Index rows,
942+
Eigen::Index cols) {
943+
return stan::math::stochastic_column_constrain<Jacobian>(
944+
this->read<conditional_var_val_t<Ret, matrix_t>>(rows - 1, cols), lp);
945+
}
946+
947+
/**
948+
* Specialization of \ref read_constrain_stochastic_column for `std::vector`
949+
* return types.
950+
*
951+
* <p>See <code>stan::math::stochastic_column_constrain(T,T&)</code>.
952+
*
953+
* @tparam Ret The type to return.
954+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
955+
* determinant of the transform.
956+
* @tparam LP Type of log probability.
957+
* @tparam Sizes A parameter pack of integral types.
958+
* @param lp The reference to the variable holding the log
959+
* probability to increment.
960+
* @param vecsize The size of the return vector.
961+
* @param sizes Pack of integrals to use to construct the return's type.
962+
* @return Standard vector of matrices transformed to have simplixes along the
963+
* columns.
964+
*/
965+
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
966+
require_std_vector_t<Ret>* = nullptr>
967+
inline auto read_constrain_stochastic_column(LP& lp, const size_t vecsize,
968+
Sizes... sizes) {
969+
std::decay_t<Ret> ret;
970+
ret.reserve(vecsize);
971+
for (size_t i = 0; i < vecsize; ++i) {
972+
ret.emplace_back(
973+
this->read_constrain_stochastic_column<value_type_t<Ret>, Jacobian>(
974+
lp, sizes...));
975+
}
976+
return ret;
977+
}
978+
979+
/**
980+
* Return the next object transformed to a matrix with simplexes along the
981+
* rows
982+
*
983+
* <p>See <code>stan::math::stochastic_row_constrain(T,T&)</code>.
984+
*
985+
* @tparam Ret The type to return.
986+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
987+
* determinant of the transform.
988+
* @tparam LP Type of log probability.
989+
* @param lp The reference to the variable holding the log
990+
* probability to increment.
991+
* @param rows Rows of matrix
992+
* @param cows Cows of matrix
993+
*/
994+
template <typename Ret, bool Jacobian, typename LP,
995+
require_not_std_vector_t<Ret>* = nullptr,
996+
require_matrix_t<Ret>* = nullptr>
997+
inline auto read_constrain_stochastic_row(LP& lp, Eigen::Index rows,
998+
Eigen::Index cols) {
999+
return stan::math::stochastic_row_constrain<Jacobian>(
1000+
this->read<conditional_var_val_t<Ret, matrix_t>>(rows, cols - 1), lp);
1001+
}
1002+
1003+
/**
1004+
* Specialization of \ref read_constrain_stochastic_row for `std::vector`
1005+
* return types.
1006+
*
1007+
* <p>See <code>stan::math::stochastic_row_constrain(T,T&)</code>.
1008+
*
1009+
* @tparam Ret The type to return.
1010+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
1011+
* determinant of the transform.
1012+
* @tparam LP Type of log probability.
1013+
* @tparam Sizes A parameter pack of integral types.
1014+
* @param lp The reference to the variable holding the log
1015+
* probability to increment.
1016+
* @param vecsize The size of the return vector.
1017+
* @param sizes Pack of integrals to use to construct the return's type.
1018+
* @return Standard vector of matrices transformed to have simplixes along the
1019+
* columns.
1020+
*/
1021+
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
1022+
require_std_vector_t<Ret>* = nullptr>
1023+
inline auto read_constrain_stochastic_row(LP& lp, const size_t vecsize,
1024+
Sizes... sizes) {
1025+
std::decay_t<Ret> ret;
1026+
ret.reserve(vecsize);
1027+
for (size_t i = 0; i < vecsize; ++i) {
1028+
ret.emplace_back(
1029+
this->read_constrain_stochastic_row<value_type_t<Ret>, Jacobian>(
1030+
lp, sizes...));
1031+
}
1032+
return ret;
1033+
}
1034+
9841035
/**
9851036
* Read a serialized lower bounded variable and unconstrain it
9861037
*
@@ -1301,6 +1352,73 @@ class deserializer {
13011352
}
13021353
return ret;
13031354
}
1355+
1356+
/**
1357+
* Read a serialized column simplex matrix and unconstrain it
1358+
*
1359+
* @tparam Ret Type of output
1360+
* @param rows Rows of matrix
1361+
* @param cows Cows of matrix
1362+
* @return Unconstrained matrix
1363+
*/
1364+
template <typename Ret, require_not_std_vector_t<Ret>* = nullptr>
1365+
inline auto read_free_stochastic_column(size_t rows, size_t cols) {
1366+
return stan::math::stochastic_column_free(this->read<Ret>(rows, cols));
1367+
}
1368+
1369+
/**
1370+
* Read serialized column simplex matrices and unconstrain them
1371+
*
1372+
* @tparam Ret Type of output
1373+
* @tparam Sizes Types of dimensions of output
1374+
* @param vecsize Vector size
1375+
* @param sizes dimensions
1376+
* @return Unconstrained matrices
1377+
*/
1378+
template <typename Ret, typename... Sizes,
1379+
require_std_vector_t<Ret>* = nullptr>
1380+
inline auto read_free_stochastic_column(size_t vecsize, Sizes... sizes) {
1381+
std::decay_t<Ret> ret;
1382+
ret.reserve(vecsize);
1383+
for (size_t i = 0; i < vecsize; ++i) {
1384+
ret.emplace_back(
1385+
read_free_stochastic_column<value_type_t<Ret>>(sizes...));
1386+
}
1387+
return ret;
1388+
}
1389+
1390+
/**
1391+
* Read a serialized row simplex matrix and unconstrain it
1392+
*
1393+
* @tparam Ret Type of output
1394+
* @param rows Rows of matrix
1395+
* @param cows Cows of matrix
1396+
* @return Unconstrained matrix
1397+
*/
1398+
template <typename Ret, require_not_std_vector_t<Ret>* = nullptr>
1399+
inline auto read_free_stochastic_row(size_t rows, size_t cols) {
1400+
return stan::math::stochastic_row_free(this->read<Ret>(rows, cols));
1401+
}
1402+
1403+
/**
1404+
* Read serialized row simplex matrices and unconstrain them
1405+
*
1406+
* @tparam Ret Type of output
1407+
* @tparam Sizes Types of dimensions of output
1408+
* @param vecsize Vector size
1409+
* @param sizes dimensions
1410+
* @return Unconstrained matrices
1411+
*/
1412+
template <typename Ret, typename... Sizes,
1413+
require_std_vector_t<Ret>* = nullptr>
1414+
inline auto read_free_stochastic_row(size_t vecsize, Sizes... sizes) {
1415+
std::decay_t<Ret> ret;
1416+
ret.reserve(vecsize);
1417+
for (size_t i = 0; i < vecsize; ++i) {
1418+
ret.emplace_back(read_free_stochastic_row<value_type_t<Ret>>(sizes...));
1419+
}
1420+
return ret;
1421+
}
13041422
};
13051423

13061424
} // namespace io

0 commit comments

Comments
 (0)