Skip to content

Commit 9566db9

Browse files
committed
Add analytical derivative support for std::beta
1 parent 34daed3 commit 9566db9

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed

include/clad/Differentiator/BuiltinDerivatives.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,61 @@ CUDA_HOST_DEVICE void hypot_pullback(T x, T y, U d_z, T* d_x, T* d_y) {
10621062
*d_y += (y / h) * d_z;
10631063
}
10641064

1065+
// 7. Special Functions
1066+
#if __cplusplus >= 201703L
1067+
template <typename T> CUDA_HOST_DEVICE inline T clad_beta_primal(T x, T y) {
1068+
#if defined(__cpp_lib_math_special_functions)
1069+
return ::std::beta(x, y);
1070+
#else
1071+
return ::std::tgamma(x) * ::std::tgamma(y) / ::std::tgamma(x + y);
1072+
#endif
1073+
}
1074+
template <typename T> CUDA_HOST_DEVICE inline T clad_digamma(T x) {
1075+
if (x <= 0.0) {
1076+
if (x == ::std::floor(x))
1077+
return (T)NAN;
1078+
return clad_digamma(1.0 - x) -
1079+
::std::acos((T)-1.0) / ::std::tan(::std::acos((T)-1.0) * x);
1080+
}
1081+
T result = 0.0;
1082+
while (x < 8.0) {
1083+
result -= 1.0 / x;
1084+
x += 1.0;
1085+
}
1086+
T inv_x = 1.0 / x;
1087+
T inv_x2 = inv_x * inv_x;
1088+
result +=
1089+
::std::log(x) - 0.5 * inv_x -
1090+
inv_x2 * (1.0 / 12.0 -
1091+
inv_x2 * (1.0 / 120.0 -
1092+
inv_x2 * (1.0 / 252.0 - inv_x2 * (1.0 / 240.0))));
1093+
return result;
1094+
}
1095+
1096+
template <typename T, typename dT>
1097+
CUDA_HOST_DEVICE ValueAndPushforward<T, dT> beta_pushforward(T x, T y, dT d_x,
1098+
dT d_y) {
1099+
T b = clad_beta_primal(x, y);
1100+
T psi_xy = clad_digamma(x + y);
1101+
dT pushforward = 0;
1102+
if (d_x)
1103+
pushforward += b * (clad_digamma(x) - psi_xy) * d_x;
1104+
if (d_y)
1105+
pushforward += b * (clad_digamma(y) - psi_xy) * d_y;
1106+
return {b, pushforward};
1107+
}
1108+
1109+
template <typename T, typename U>
1110+
CUDA_HOST_DEVICE void beta_pullback(T x, T y, U d_z, T* d_x, T* d_y) {
1111+
T b = clad_beta_primal(x, y);
1112+
T psi_xy = clad_digamma(x + y);
1113+
if (d_x)
1114+
*d_x += b * (clad_digamma(x) - psi_xy) * d_z;
1115+
if (d_y)
1116+
*d_y += b * (clad_digamma(y) - psi_xy) * d_z;
1117+
}
1118+
#endif
1119+
10651120
} // namespace std
10661121

10671122
CUDA_HOST_DEVICE inline ValueAndPushforward<float, float>
@@ -1327,6 +1382,12 @@ using std::pow_pullback;
13271382
using std::pow_pushforward;
13281383
using std::sqrt_pushforward;
13291384

1385+
// 7. Special Functions
1386+
#if __cplusplus >= 201703L
1387+
using std::beta_pullback;
1388+
using std::beta_pushforward;
1389+
#endif
1390+
13301391
namespace class_functions {
13311392
template <typename T, typename U>
13321393
void constructor_pullback(ValueAndPushforward<T, U> rhs,

test/Features/stl-cmath.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,14 @@
123123
#include <cmath>
124124
#include <iostream>
125125
#include <iomanip>
126+
#if !defined(__cpp_lib_math_special_functions)
127+
namespace std {
128+
template <typename T>
129+
inline T beta(T x, T y) {
130+
return std::tgamma(x) * std::tgamma(y) / std::tgamma(x + y);
131+
}
132+
}
133+
#endif
126134

127135
template <typename T>
128136
T get_tolerance() {
@@ -295,6 +303,10 @@ DEFINE_FUNCTIONS(atanh) // x in [-1,1]
295303
//
296304
DEFINE_FUNCTIONS(erf) // x in (-inf,+inf)
297305

306+
template<typename T> T f_beta(T x){ return std::beta(x,(T)2.0); } // x in (0, +inf)
307+
inline float f_betaf(float x){ return std::beta(x, 2.0f); }
308+
inline long double f_betal(long double x){ return std::beta(x, 2.0L); }
309+
298310
int main() {
299311
// Absolute value
300312
CHECK(abs);
@@ -352,6 +364,7 @@ int main() {
352364

353365
// Error / Gamma functions
354366
CHECK_ALL(erf);
367+
CHECK_ALL_RANGE(beta, {0.1, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0});
355368

356369
return 0;
357370
}

test/FirstDerivative/BuiltinDerivatives.C

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,17 @@
66

77
#include "clad/Differentiator/Differentiator.h"
88
#include "../TestUtils.h"
9+
#include <cmath>
10+
11+
#if !defined(__cpp_lib_math_special_functions)
12+
namespace std {
13+
// Mock std::beta on Apple platforms so Clad's AST has a target to differentiate
14+
inline double beta(double x, double y) {
15+
return std::tgamma(x) * std::tgamma(y) / std::tgamma(x + y);
16+
}
17+
}
18+
#endif
19+
920
extern "C" int printf(const char* fmt, ...);
1021

1122

@@ -515,6 +526,14 @@ double f_custom_min(double x, double y) { return std::min(x, y, std::greater<dou
515526
// CHECK-NEXT: return _t0.pushforward;
516527
// CHECK-NEXT: }
517528

529+
double f_beta(double x, double y) { return std::beta(x, y); }
530+
// CHECK: double f_beta_darg0(double x, double y) {
531+
// CHECK-NEXT: double _d_x = 1;
532+
// CHECK-NEXT: double _d_y = 0;
533+
// CHECK-NEXT: {{.*}}ValueAndPushforward<double, double> _t0 = {{.*}}beta_pushforward(x, y, _d_x, _d_y);
534+
// CHECK-NEXT: return _t0.pushforward;
535+
// CHECK-NEXT: }
536+
518537
int main () { //expected-no-diagnostics
519538
float f_result[2];
520539
double d_result[2];
@@ -695,5 +714,8 @@ int main () { //expected-no-diagnostics
695714
auto d_custom_min = clad::differentiate(f_custom_min, 0);
696715
printf("Result is = %.6f\n", d_custom_min.execute(2, 3)); // CHECK-EXEC: Result is = 0.000000
697716

717+
auto d_beta = clad::differentiate(f_beta, 0);
718+
printf("Result is = %.6f\n", d_beta.execute(2.0, 3.0)); // CHECK-EXEC: Result is = -0.090278
719+
698720
return 0;
699721
}

0 commit comments

Comments
 (0)