Skip to content

Commit 5cf7bb8

Browse files
committed
adds row and col stochastic constraints to deserializer
1 parent 1b09798 commit 5cf7bb8

File tree

2 files changed

+454
-439
lines changed

2 files changed

+454
-439
lines changed

src/stan/io/deserializer.hpp

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -981,6 +981,123 @@ class deserializer {
981981
return ret;
982982
}
983983

984+
/**
985+
* Return the next object transformed to a matrix with simplexes along the columns
986+
*
987+
* <p>See <code>stan::math::stochastic_column_constrain(T,T&)</code>.
988+
*
989+
* @tparam Ret The type to return.
990+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
991+
* determinant of the transform.
992+
* @tparam LP Type of log probability.
993+
* @param lp The reference to the variable holding the log
994+
* probability to increment.
995+
* @param rows Rows of matrix
996+
* @param cows Cows of matrix
997+
*/
998+
template <typename Ret, bool Jacobian, typename LP,
999+
require_not_std_vector_t<Ret>* = nullptr,
1000+
require_matrix_t<Ret>* = nullptr>
1001+
inline auto read_constrain_stochastic_column(LP& lp, Eigen::Index rows, Eigen::Index cols) {
1002+
using stan::math::stochastic_column_constrain;
1003+
if (Jacobian) {
1004+
return stochastic_column_constrain(
1005+
this->read<conditional_var_val_t<Ret, matrix_t>>(rows - 1, cols), lp);
1006+
} else {
1007+
return stochastic_column_constrain(
1008+
this->read<conditional_var_val_t<Ret, matrix_t>>(rows - 1, cols));
1009+
}
1010+
}
1011+
1012+
/**
1013+
* Specialization of \ref read_constrain_stochastic_column for `std::vector` return types.
1014+
*
1015+
* <p>See <code>stan::math::stochastic_column_constrain(T,T&)</code>.
1016+
*
1017+
* @tparam Ret The type to return.
1018+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
1019+
* determinant of the transform.
1020+
* @tparam LP Type of log probability.
1021+
* @tparam Sizes A parameter pack of integral types.
1022+
* @param lp The reference to the variable holding the log
1023+
* probability to increment.
1024+
* @param vecsize The size of the return vector.
1025+
* @param sizes Pack of integrals to use to construct the return's type.
1026+
* @return Standard vector of matrices transformed to have simplixes along the columns.
1027+
*/
1028+
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
1029+
require_std_vector_t<Ret>* = nullptr>
1030+
inline auto read_constrain_stochastic_column(LP& lp, const size_t vecsize,
1031+
Sizes... sizes) {
1032+
std::decay_t<Ret> ret;
1033+
ret.reserve(vecsize);
1034+
for (size_t i = 0; i < vecsize; ++i) {
1035+
ret.emplace_back(
1036+
this->read_constrain_stochastic_column<value_type_t<Ret>, Jacobian>(
1037+
lp, sizes...));
1038+
}
1039+
return ret;
1040+
}
1041+
1042+
/**
1043+
* Return the next object transformed to a matrix with simplexes along the rows
1044+
*
1045+
* <p>See <code>stan::math::stochastic_row_constrain(T,T&)</code>.
1046+
*
1047+
* @tparam Ret The type to return.
1048+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
1049+
* determinant of the transform.
1050+
* @tparam LP Type of log probability.
1051+
* @param lp The reference to the variable holding the log
1052+
* probability to increment.
1053+
* @param rows Rows of matrix
1054+
* @param cows Cows of matrix
1055+
*/
1056+
template <typename Ret, bool Jacobian, typename LP,
1057+
require_not_std_vector_t<Ret>* = nullptr,
1058+
require_matrix_t<Ret>* = nullptr>
1059+
inline auto read_constrain_stochastic_row(LP& lp, Eigen::Index rows, Eigen::Index cols) {
1060+
using stan::math::stochastic_row_constrain;
1061+
if (Jacobian) {
1062+
return stochastic_row_constrain(
1063+
this->read<conditional_var_val_t<Ret, matrix_t>>(rows, cols - 1), lp);
1064+
} else {
1065+
return stochastic_row_constrain(
1066+
this->read<conditional_var_val_t<Ret, matrix_t>>(rows, cols - 1));
1067+
}
1068+
}
1069+
1070+
/**
1071+
* Specialization of \ref read_constrain_stochastic_row for `std::vector` return types.
1072+
*
1073+
* <p>See <code>stan::math::stochastic_row_constrain(T,T&)</code>.
1074+
*
1075+
* @tparam Ret The type to return.
1076+
* @tparam Jacobian Whether to increment the log of the absolute Jacobian
1077+
* determinant of the transform.
1078+
* @tparam LP Type of log probability.
1079+
* @tparam Sizes A parameter pack of integral types.
1080+
* @param lp The reference to the variable holding the log
1081+
* probability to increment.
1082+
* @param vecsize The size of the return vector.
1083+
* @param sizes Pack of integrals to use to construct the return's type.
1084+
* @return Standard vector of matrices transformed to have simplixes along the columns.
1085+
*/
1086+
template <typename Ret, bool Jacobian, typename LP, typename... Sizes,
1087+
require_std_vector_t<Ret>* = nullptr>
1088+
inline auto read_constrain_stochastic_row(LP& lp, const size_t vecsize,
1089+
Sizes... sizes) {
1090+
std::decay_t<Ret> ret;
1091+
ret.reserve(vecsize);
1092+
for (size_t i = 0; i < vecsize; ++i) {
1093+
ret.emplace_back(
1094+
this->read_constrain_stochastic_row<value_type_t<Ret>, Jacobian>(
1095+
lp, sizes...));
1096+
}
1097+
return ret;
1098+
}
1099+
1100+
9841101
/**
9851102
* Read a serialized lower bounded variable and unconstrain it
9861103
*
@@ -1301,6 +1418,73 @@ class deserializer {
13011418
}
13021419
return ret;
13031420
}
1421+
1422+
/**
1423+
* Read a serialized column simplex matrix and unconstrain it
1424+
*
1425+
* @tparam Ret Type of output
1426+
* @param rows Rows of matrix
1427+
* @param cows Cows of matrix
1428+
* @return Unconstrained matrix
1429+
*/
1430+
template <typename Ret, require_not_std_vector_t<Ret>* = nullptr>
1431+
inline auto read_free_stochastic_column(size_t rows, size_t cols) {
1432+
return stan::math::stochastic_column_free(this->read<Ret>(rows, cols));
1433+
}
1434+
1435+
/**
1436+
* Read serialized column simplex matrices and unconstrain them
1437+
*
1438+
* @tparam Ret Type of output
1439+
* @tparam Sizes Types of dimensions of output
1440+
* @param vecsize Vector size
1441+
* @param sizes dimensions
1442+
* @return Unconstrained matrices
1443+
*/
1444+
template <typename Ret, typename... Sizes,
1445+
require_std_vector_t<Ret>* = nullptr>
1446+
inline auto read_free_stochastic_column(size_t vecsize, Sizes... sizes) {
1447+
std::decay_t<Ret> ret;
1448+
ret.reserve(vecsize);
1449+
for (size_t i = 0; i < vecsize; ++i) {
1450+
ret.emplace_back(read_free_stochastic_column<value_type_t<Ret>>(sizes...));
1451+
}
1452+
return ret;
1453+
}
1454+
1455+
/**
1456+
* Read a serialized row simplex matrix and unconstrain it
1457+
*
1458+
* @tparam Ret Type of output
1459+
* @param rows Rows of matrix
1460+
* @param cows Cows of matrix
1461+
* @return Unconstrained matrix
1462+
*/
1463+
template <typename Ret, require_not_std_vector_t<Ret>* = nullptr>
1464+
inline auto read_free_stochastic_row(size_t rows, size_t cols) {
1465+
return stan::math::stochastic_row_free(this->read<Ret>(rows, cols));
1466+
}
1467+
1468+
/**
1469+
* Read serialized row simplex matrices and unconstrain them
1470+
*
1471+
* @tparam Ret Type of output
1472+
* @tparam Sizes Types of dimensions of output
1473+
* @param vecsize Vector size
1474+
* @param sizes dimensions
1475+
* @return Unconstrained matrices
1476+
*/
1477+
template <typename Ret, typename... Sizes,
1478+
require_std_vector_t<Ret>* = nullptr>
1479+
inline auto read_free_stochastic_row(size_t vecsize, Sizes... sizes) {
1480+
std::decay_t<Ret> ret;
1481+
ret.reserve(vecsize);
1482+
for (size_t i = 0; i < vecsize; ++i) {
1483+
ret.emplace_back(read_free_stochastic_row<value_type_t<Ret>>(sizes...));
1484+
}
1485+
return ret;
1486+
}
1487+
13041488
};
13051489

13061490
} // namespace io

0 commit comments

Comments
 (0)