@@ -1320,5 +1320,73 @@ inline float modified_bessel_k0_forward(T x) {
13201320 return ::metal::exp (-x) * (0.5 * (b - p)) / ::metal::sqrt (x);
13211321} // modified_bessel_k0_forward(T x)
13221322
1323+ template <typename T>
1324+ float modified_bessel_k1_forward (T x) {
1325+ constexpr float A[] = {
1326+ -7.02386347938628759343e-18 ,
1327+ -2.42744985051936593393e-15 ,
1328+ -6.66690169419932900609e-13 ,
1329+ -1.41148839263352776110e-10 ,
1330+ -2.21338763073472585583e-08 ,
1331+ -2.43340614156596823496e-06 ,
1332+ -1.73028895751305206302e-04 ,
1333+ -6.97572385963986435018e-03 ,
1334+ -1.22611180822657148235e-01 ,
1335+ -3.53155960776544875667e-01 ,
1336+ +1.52530022733894777053e+00 ,
1337+ };
1338+
1339+ constexpr float B[] = {
1340+ -5.75674448366501715755e-18 , +1.79405087314755922667e-17 ,
1341+ -5.68946255844285935196e-17 , +1.83809354436663880070e-16 ,
1342+ -6.05704724837331885336e-16 , +2.03870316562433424052e-15 ,
1343+ -7.01983709041831346144e-15 , +2.47715442448130437068e-14 ,
1344+ -8.97670518232499435011e-14 , +3.34841966607842919884e-13 ,
1345+ -1.28917396095102890680e-12 , +5.13963967348173025100e-12 ,
1346+ -2.12996783842756842877e-11 , +9.21831518760500529508e-11 ,
1347+ -4.19035475934189648750e-10 , +2.01504975519703286596e-09 ,
1348+ -1.03457624656780970260e-08 , +5.74108412545004946722e-08 ,
1349+ -3.50196060308781257119e-07 , +2.40648494783721712015e-06 ,
1350+ -1.93619797416608296024e-05 , +1.95215518471351631108e-04 ,
1351+ -2.85781685962277938680e-03 , +1.03923736576817238437e-01 ,
1352+ +2.72062619048444266945e+00 ,
1353+ };
1354+
1355+ if (x == 0.0 ) {
1356+ return INFINITY;
1357+ }
1358+
1359+ if (x < 0.0 ) {
1360+ return NAN;
1361+ }
1362+
1363+ float p;
1364+ float q = 0.0 ;
1365+
1366+ if (x <= 2.0 ) {
1367+ float a = A[0 ];
1368+
1369+ for (uint8_t index = 1 ; index < 11 ; index++) {
1370+ p = q;
1371+ q = a;
1372+ a = (x * x - T (2.0 )) * q - p + A[index];
1373+ }
1374+
1375+ return ::metal::precise::log (T (0.5 ) * x) * modified_bessel_i1_forward (x) +
1376+ 0.5 * (a - p) / x;
1377+ }
1378+
1379+ float b = B[0 ];
1380+
1381+ for (uint8_t index = 1 ; index < 25 ; index++) {
1382+ p = q;
1383+ q = b;
1384+ b = (8.0 / x - 2.0 ) * q - p + B[index];
1385+ }
1386+
1387+ return ::metal::precise::exp (-x) * (0.5 * (b - p)) /
1388+ ::metal::precise::sqrt (x);
1389+ }
1390+
13231391} // namespace metal
13241392} // namespace c10
0 commit comments