Skip to content

Commit 10bc8f2

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS][BE] Migrate polar to use functor (pytorch#147184)
Pull Request resolved: pytorch#147184 Approved by: https://github.com/dcci ghstack dependencies: pytorch#147182, pytorch#147183
1 parent 278ffd8 commit 10bc8f2

File tree

2 files changed

+13
-17
lines changed

2 files changed

+13
-17
lines changed

aten/src/ATen/native/mps/kernels/BinaryKernel.metal

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include <c10/metal/special_math.h>
2+
#include <c10/metal/utils.h>
23
#include <metal_stdlib>
34
using namespace metal;
45

@@ -74,6 +75,15 @@ struct nextafter_functor {
7475
}
7576
};
7677

78+
struct polar_functor {
79+
template <typename U>
80+
using ret_type = c10::metal::vec2type_t<U>;
81+
template <typename T>
82+
inline ret_type<T> operator()(const T a, const T b) {
83+
return ret_type<T>(a * cos(b), a * sin(b));
84+
}
85+
};
86+
7787
// Future BinaryTensorIterator
7888
template <typename T, typename F>
7989
using result_of = decltype(::metal::declval<F>()(
@@ -153,22 +163,8 @@ REGISTER_BINARY_INDEXING_OP(zeta, bfloat);
153163
#endif
154164

155165
// Complex binary functions
156-
template <typename T>
157-
kernel void polar(
158-
constant void* abs_ [[buffer(0)]],
159-
constant void* angle_ [[buffer(1)]],
160-
device void* out_ [[buffer(2)]],
161-
constant uint3* offsets [[buffer(3)]],
162-
uint tid [[thread_position_in_grid]]) {
163-
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
164-
constant T* angle = (constant T*)((constant uint8_t*)angle_ + offsets[tid].z);
165-
constant T* abs = (constant T*)((constant uint8_t*)abs_ + offsets[tid].y);
166-
out[0] = abs[0] * cos(angle[0]);
167-
out[1] = abs[0] * sin(angle[0]);
168-
}
169-
170-
REGISTER_BINARY_OP(polar, float);
171-
REGISTER_BINARY_OP(polar, half);
166+
REGISTER_BINARY_INDEXING_OP(polar, float);
167+
REGISTER_BINARY_INDEXING_OP(polar, half);
172168

173169
template <typename T>
174170
kernel void complex_mul(

aten/src/ATen/native/mps/operations/BinaryKernel.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ static void zeta_mps_kernel(TensorIteratorBase& iter) {
141141
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
142142
auto iter = TensorIteratorConfig().add_output(output_as_real).add_input(abs).add_input(angle).build();
143143

144-
mps::binary_mps_impl(iter, "polar", false);
144+
mps::binary_mps_impl(iter, "polar");
145145
return output;
146146
}
147147

0 commit comments

Comments
 (0)