Skip to content

Commit 8c22bde

Browse files
committed
Add analytical derivative support for std::beta
1 parent 34daed3 commit 8c22bde

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

include/clad/Differentiator/BuiltinDerivatives.h

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,48 @@ CUDA_HOST_DEVICE void hypot_pullback(T x, T y, U d_z, T* d_x, T* d_y) {
10621062
*d_y += (y / h) * d_z;
10631063
}
10641064

1065+
// 7. Special Functions
1066+
#if __cplusplus >= 201703L
1067+
template <typename T>
1068+
CUDA_HOST_DEVICE inline T clad_digamma(T x) {
1069+
if (x <= 0.0) {
1070+
if (x == ::std::floor(x)) return (T)NAN;
1071+
return clad_digamma(1.0 - x) - ::std::acos((T)-1.0) / ::std::tan(::std::acos((T)-1.0) * x);
1072+
}
1073+
T result = 0.0;
1074+
while (x < 8.0) {
1075+
result -= 1.0 / x;
1076+
x += 1.0;
1077+
}
1078+
T inv_x = 1.0 / x;
1079+
T inv_x2 = inv_x * inv_x;
1080+
result += ::std::log(x) - 0.5 * inv_x
1081+
- inv_x2 * (1.0/12.0
1082+
- inv_x2 * (1.0/120.0
1083+
- inv_x2 * (1.0/252.0
1084+
- inv_x2 * (1.0/240.0))));
1085+
return result;
1086+
}
1087+
1088+
template <typename T, typename dT>
1089+
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> beta_pushforward(T x, T y, dT d_x, dT d_y) {
1090+
T b = ::std::beta(x, y);
1091+
T psi_xy = clad_digamma(x + y);
1092+
dT pushforward = 0;
1093+
if (d_x) pushforward += b * (clad_digamma(x) - psi_xy) * d_x;
1094+
if (d_y) pushforward += b * (clad_digamma(y) - psi_xy) * d_y;
1095+
return {b, pushforward};
1096+
}
1097+
1098+
template <typename T, typename U>
1099+
CUDA_HOST_DEVICE void beta_pullback(T x, T y, U d_z, T* d_x, T* d_y) {
1100+
T b = ::std::beta(x, y);
1101+
T psi_xy = clad_digamma(x + y);
1102+
if (d_x) *d_x += b * (clad_digamma(x) - psi_xy) * d_z;
1103+
if (d_y) *d_y += b * (clad_digamma(y) - psi_xy) * d_z;
1104+
}
1105+
#endif
1106+
10651107
} // namespace std
10661108

10671109
CUDA_HOST_DEVICE inline ValueAndPushforward<float, float>
@@ -1327,6 +1369,12 @@ using std::pow_pullback;
13271369
using std::pow_pushforward;
13281370
using std::sqrt_pushforward;
13291371

1372+
// 7. Special Functions
1373+
#if __cplusplus >= 201703L
1374+
using std::beta_pullback;
1375+
using std::beta_pushforward;
1376+
#endif
1377+
13301378
namespace class_functions {
13311379
template <typename T, typename U>
13321380
void constructor_pullback(ValueAndPushforward<T, U> rhs,

test/FirstDerivative/BuiltinDerivatives.C

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,6 +515,14 @@ double f_custom_min(double x, double y) { return std::min(x, y, std::greater<dou
515515
// CHECK-NEXT: return _t0.pushforward;
516516
// CHECK-NEXT: }
517517

518+
double f_beta(double x, double y) { return std::beta(x, y); }
519+
// CHECK: double f_beta_darg0(double x, double y) {
520+
// CHECK-NEXT: double _d_x = 1;
521+
// CHECK-NEXT: double _d_y = 0;
522+
// CHECK-NEXT: {{.*}}ValueAndPushforward<double, double> _t0 = {{.*}}beta_pushforward(x, y, _d_x, _d_y);
523+
// CHECK-NEXT: return _t0.pushforward;
524+
// CHECK-NEXT: }
525+
518526
int main () { //expected-no-diagnostics
519527
float f_result[2];
520528
double d_result[2];
@@ -695,5 +703,8 @@ int main () { //expected-no-diagnostics
695703
auto d_custom_min = clad::differentiate(f_custom_min, 0);
696704
printf("Result is = %.6f\n", d_custom_min.execute(2, 3)); // CHECK-EXEC: Result is = 0.000000
697705

706+
auto d_beta = clad::differentiate(f_beta, 0);
707+
printf("Result is = %.6f\n", d_beta.execute(2.0, 3.0)); // CHECK-EXEC: Result is = -0.090278
708+
698709
return 0;
699710
}

0 commit comments

Comments
 (0)