|
1 | 1 | #include <c10/metal/special_math.h> |
| 2 | +#include <c10/metal/utils.h> |
2 | 3 | #include <metal_stdlib> |
3 | 4 | using namespace metal; |
4 | 5 |
|
@@ -74,6 +75,15 @@ struct nextafter_functor { |
74 | 75 | } |
75 | 76 | }; |
76 | 77 |
|
| 78 | +struct polar_functor { |
| 79 | + template <typename U> |
| 80 | + using ret_type = c10::metal::vec2type_t<U>; |
| 81 | + template <typename T> |
| 82 | + inline ret_type<T> operator()(const T a, const T b) { |
| 83 | + return ret_type<T>(a * cos(b), a * sin(b)); |
| 84 | + } |
| 85 | +}; |
| 86 | + |
77 | 87 | // Future BinaryTensorIterator |
78 | 88 | template <typename T, typename F> |
79 | 89 | using result_of = decltype(::metal::declval<F>()( |
@@ -153,22 +163,8 @@ REGISTER_BINARY_INDEXING_OP(zeta, bfloat); |
153 | 163 | #endif |
154 | 164 |
|
155 | 165 | // Complex binary functions |
156 | | -template <typename T> |
157 | | -kernel void polar( |
158 | | - constant void* abs_ [[buffer(0)]], |
159 | | - constant void* angle_ [[buffer(1)]], |
160 | | - device void* out_ [[buffer(2)]], |
161 | | - constant uint3* offsets [[buffer(3)]], |
162 | | - uint tid [[thread_position_in_grid]]) { |
163 | | - device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x); |
164 | | - constant T* angle = (constant T*)((constant uint8_t*)angle_ + offsets[tid].z); |
165 | | - constant T* abs = (constant T*)((constant uint8_t*)abs_ + offsets[tid].y); |
166 | | - out[0] = abs[0] * cos(angle[0]); |
167 | | - out[1] = abs[0] * sin(angle[0]); |
168 | | -} |
169 | | - |
170 | | -REGISTER_BINARY_OP(polar, float); |
171 | | -REGISTER_BINARY_OP(polar, half); |
| 166 | +REGISTER_BINARY_INDEXING_OP(polar, float); |
| 167 | +REGISTER_BINARY_INDEXING_OP(polar, half); |
172 | 168 |
|
173 | 169 | template <typename T> |
174 | 170 | kernel void complex_mul( |
|
0 commit comments