Skip to content

Commit 0ed3421

Browse files
dccipytorchmergebot
authored andcommitted
[MPS] Add support for modified_bessel_k1 to eager and inductor. (pytorch#149687)
Pull Request resolved: pytorch#149687 Approved by: https://github.com/malfet
1 parent 0a396a8 commit 0ed3421

File tree

7 files changed

+81
-2
lines changed

7 files changed

+81
-2
lines changed

aten/src/ATen/native/mps/kernels/SpecialOps.metal

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ DEFINE_UNARY_FLOATING_FUNCTOR(bessel_j1_forward);
88
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i0_forward);
99
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_i1_forward);
1010
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_k0_forward);
11+
DEFINE_UNARY_FLOATING_FUNCTOR(modified_bessel_k1_forward);
1112
DEFINE_UNARY_FLOATING_FUNCTOR(i0);
1213
DEFINE_UNARY_FLOATING_FUNCTOR(i0e);
1314
DEFINE_UNARY_FLOATING_FUNCTOR(i1);
@@ -53,6 +54,7 @@ struct bessel_y1_forward_functor {
5354
REGISTER_UNARY_OP(modified_bessel_i0_forward, DTI, DTO); \
5455
REGISTER_UNARY_OP(modified_bessel_i1_forward, DTI, DTO); \
5556
REGISTER_UNARY_OP(modified_bessel_k0_forward, DTI, DTO); \
57+
REGISTER_UNARY_OP(modified_bessel_k1_forward, DTI, DTO); \
5658
REGISTER_UNARY_OP(bessel_y0_forward, DTI, DTO); \
5759
REGISTER_UNARY_OP(bessel_y1_forward, DTI, DTO); \
5860
REGISTER_UNARY_OP(i0, DTI, DTO); \

aten/src/ATen/native/mps/operations/SpecialOps.mm

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ static void modified_bessel_k0_kernel_mps(TensorIteratorBase& iter) {
5656
lib.exec_unary_kernel(iter, "modified_bessel_k0_forward");
5757
}
5858

59+
static void modified_bessel_k1_kernel_mps(TensorIteratorBase& iter) {
60+
lib.exec_unary_kernel(iter, "modified_bessel_k1_forward");
61+
}
62+
5963
static void bessel_y0_kernel_mps(TensorIteratorBase& iter) {
6064
lib.exec_unary_kernel(iter, "bessel_y0_forward");
6165
}
@@ -73,6 +77,7 @@ static void bessel_y1_kernel_mps(TensorIteratorBase& iter) {
7377
REGISTER_DISPATCH(special_modified_bessel_i0_stub, &modified_bessel_i0_kernel_mps)
7478
REGISTER_DISPATCH(special_modified_bessel_i1_stub, &modified_bessel_i1_kernel_mps)
7579
REGISTER_DISPATCH(special_modified_bessel_k0_stub, &modified_bessel_k0_kernel_mps)
80+
REGISTER_DISPATCH(special_modified_bessel_k1_stub, &modified_bessel_k1_kernel_mps)
7681
REGISTER_DISPATCH(special_bessel_y0_stub, &bessel_y0_kernel_mps)
7782
REGISTER_DISPATCH(special_bessel_y1_stub, &bessel_y1_kernel_mps)
7883
REGISTER_DISPATCH(special_spherical_bessel_j0_stub, &spherical_bessel_j0_kernel_mps)

aten/src/ATen/native/native_functions.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15486,7 +15486,7 @@
1548615486

1548715487
- func: special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
1548815488
dispatch:
15489-
CPU, CUDA: special_modified_bessel_k1_out
15489+
CPU, CUDA, MPS: special_modified_bessel_k1_out
1549015490
python_module: special
1549115491
structured_inherits: TensorIteratorBase
1549215492
structured: True

c10/metal/special_math.h

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1320,5 +1320,73 @@ inline float modified_bessel_k0_forward(T x) {
13201320
return ::metal::exp(-x) * (0.5 * (b - p)) / ::metal::sqrt(x);
13211321
} // modified_bessel_k0_forward(T x)
13221322

1323+
template <typename T>
1324+
float modified_bessel_k1_forward(T x) {
1325+
constexpr float A[] = {
1326+
-7.02386347938628759343e-18,
1327+
-2.42744985051936593393e-15,
1328+
-6.66690169419932900609e-13,
1329+
-1.41148839263352776110e-10,
1330+
-2.21338763073472585583e-08,
1331+
-2.43340614156596823496e-06,
1332+
-1.73028895751305206302e-04,
1333+
-6.97572385963986435018e-03,
1334+
-1.22611180822657148235e-01,
1335+
-3.53155960776544875667e-01,
1336+
+1.52530022733894777053e+00,
1337+
};
1338+
1339+
constexpr float B[] = {
1340+
-5.75674448366501715755e-18, +1.79405087314755922667e-17,
1341+
-5.68946255844285935196e-17, +1.83809354436663880070e-16,
1342+
-6.05704724837331885336e-16, +2.03870316562433424052e-15,
1343+
-7.01983709041831346144e-15, +2.47715442448130437068e-14,
1344+
-8.97670518232499435011e-14, +3.34841966607842919884e-13,
1345+
-1.28917396095102890680e-12, +5.13963967348173025100e-12,
1346+
-2.12996783842756842877e-11, +9.21831518760500529508e-11,
1347+
-4.19035475934189648750e-10, +2.01504975519703286596e-09,
1348+
-1.03457624656780970260e-08, +5.74108412545004946722e-08,
1349+
-3.50196060308781257119e-07, +2.40648494783721712015e-06,
1350+
-1.93619797416608296024e-05, +1.95215518471351631108e-04,
1351+
-2.85781685962277938680e-03, +1.03923736576817238437e-01,
1352+
+2.72062619048444266945e+00,
1353+
};
1354+
1355+
if (x == 0.0) {
1356+
return INFINITY;
1357+
}
1358+
1359+
if (x < 0.0) {
1360+
return NAN;
1361+
}
1362+
1363+
float p;
1364+
float q = 0.0;
1365+
1366+
if (x <= 2.0) {
1367+
float a = A[0];
1368+
1369+
for (uint8_t index = 1; index < 11; index++) {
1370+
p = q;
1371+
q = a;
1372+
a = (x * x - T(2.0)) * q - p + A[index];
1373+
}
1374+
1375+
return ::metal::precise::log(T(0.5) * x) * modified_bessel_i1_forward(x) +
1376+
0.5 * (a - p) / x;
1377+
}
1378+
1379+
float b = B[0];
1380+
1381+
for (uint8_t index = 1; index < 25; index++) {
1382+
p = q;
1383+
q = b;
1384+
b = (8.0 / x - 2.0) * q - p + B[index];
1385+
}
1386+
1387+
return ::metal::precise::exp(-x) * (0.5 * (b - p)) /
1388+
::metal::precise::sqrt(x);
1389+
}
1390+
13231391
} // namespace metal
13241392
} // namespace c10

test/inductor/test_mps_basic.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def test_cast(self, dtype):
9898
"modified_bessel_i0",
9999
"modified_bessel_i1",
100100
"modified_bessel_k0",
101+
"modified_bessel_k1",
101102
"entr",
102103
]
103104

test/test_mps.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -663,7 +663,6 @@ def mps_ops_modifier(ops):
663663
'special.hermite_polynomial_he': None,
664664
'special.laguerre_polynomial_l': None,
665665
'special.log_ndtr': None,
666-
'special.modified_bessel_k1': None,
667666
'special.ndtri': None,
668667
'special.scaled_modified_bessel_k0': None,
669668
'special.scaled_modified_bessel_k1': None,

torch/_inductor/codegen/mps.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,10 @@ def modified_bessel_i1(x: CSEVariable) -> str:
414414
def modified_bessel_k0(x: CSEVariable) -> str:
415415
return f"c10::metal::modified_bessel_k0_forward({x})"
416416

417+
@staticmethod
418+
def modified_bessel_k1(x: CSEVariable) -> str:
419+
return f"c10::metal::modified_bessel_k1_forward({x})"
420+
417421

418422
MetalOverrides._initialize_pointwise_overrides("mps")
419423

0 commit comments

Comments
 (0)