@@ -1062,6 +1062,48 @@ CUDA_HOST_DEVICE void hypot_pullback(T x, T y, U d_z, T* d_x, T* d_y) {
10621062 *d_y += (y / h) * d_z;
10631063}
10641064
1065+ // 7. Special Functions
1066+ #if __cplusplus >= 201703L
1067+ template <typename T>
1068+ CUDA_HOST_DEVICE inline T clad_digamma (T x) {
1069+ if (x <= 0.0 ) {
1070+ if (x == ::std::floor (x)) return (T)NAN;
1071+ return clad_digamma (1.0 - x) - ::std::acos ((T)-1.0 ) / ::std::tan (::std::acos ((T)-1.0 ) * x);
1072+ }
1073+ T result = 0.0 ;
1074+ while (x < 8.0 ) {
1075+ result -= 1.0 / x;
1076+ x += 1.0 ;
1077+ }
1078+ T inv_x = 1.0 / x;
1079+ T inv_x2 = inv_x * inv_x;
1080+ result += ::std::log (x) - 0.5 * inv_x
1081+ - inv_x2 * (1.0 /12.0
1082+ - inv_x2 * (1.0 /120.0
1083+ - inv_x2 * (1.0 /252.0
1084+ - inv_x2 * (1.0 /240.0 ))));
1085+ return result;
1086+ }
1087+
1088+ template <typename T, typename dT>
1089+ CUDA_HOST_DEVICE ValueAndPushforward<T, dT> beta_pushforward (T x, T y, dT d_x, dT d_y) {
1090+ T b = ::std::beta (x, y);
1091+ T psi_xy = clad_digamma (x + y);
1092+ dT pushforward = 0 ;
1093+ if (d_x) pushforward += b * (clad_digamma (x) - psi_xy) * d_x;
1094+ if (d_y) pushforward += b * (clad_digamma (y) - psi_xy) * d_y;
1095+ return {b, pushforward};
1096+ }
1097+
1098+ template <typename T, typename U>
1099+ CUDA_HOST_DEVICE void beta_pullback (T x, T y, U d_z, T* d_x, T* d_y) {
1100+ T b = ::std::beta (x, y);
1101+ T psi_xy = clad_digamma (x + y);
1102+ if (d_x) *d_x += b * (clad_digamma (x) - psi_xy) * d_z;
1103+ if (d_y) *d_y += b * (clad_digamma (y) - psi_xy) * d_z;
1104+ }
1105+ #endif
1106+
10651107} // namespace std
10661108
10671109CUDA_HOST_DEVICE inline ValueAndPushforward<float , float >
@@ -1327,6 +1369,12 @@ using std::pow_pullback;
13271369using std::pow_pushforward;
13281370using std::sqrt_pushforward;
13291371
1372+ // 7. Special Functions
1373+ #if __cplusplus >= 201703L
1374+ using std::beta_pullback;
1375+ using std::beta_pushforward;
1376+ #endif
1377+
13301378namespace class_functions {
13311379template <typename T, typename U>
13321380void constructor_pullback (ValueAndPushforward<T, U> rhs,
0 commit comments