@@ -18,9 +18,17 @@ struct fmin_functor {
1818
// Binary functor wrapping ::metal::copysign. After this change it exposes two
// operator() overloads, selected by enable_if_t on the return type according
// to is_floating_point_v<T>, so exactly one overload is viable for a given T.
1919struct copysign_functor {
2020 template <typename T>
21- inline T operator ()(const T a, const T b) {
// Floating-point T: pass the operands straight to ::metal::copysign and cast
// the result back to T, so the return type equals the input type.
21+ inline enable_if_t <is_floating_point_v<T>, T> operator ()(
22+ const T a,
23+ const T b) {
2224 return static_cast <T>(::metal::copysign (a, b));
2325 }
// Non-floating-point T: both operands are converted to float before the call
// and the result is returned as float rather than T.
26+ template <typename T>
27+ inline enable_if_t <!is_floating_point_v<T>, float > operator ()(
28+ const T a,
29+ const T b) {
30+ return ::metal::copysign (static_cast <float >(a), static_cast <float >(b));
31+ }
2432};
2533
2634struct zeta_functor {
@@ -111,20 +119,6 @@ kernel void binary_dense(
111119 device result_of<DTYPE, NAME##_functor> * out_, \
112120 uint tid)
113121
114- template <typename T>
115- kernel void copysign_integral (
116- constant void * input_ [[buffer(0 )]],
117- constant void* other_ [[buffer(1 )]],
118- device void* out_ [[buffer(2 )]],
119- constant uint3* offsets [[buffer(3 )]],
120- uint tid [[thread_position_in_grid]]) {
121- device float * out = (device float *)((device uint8_t *)out_ + offsets[tid].x );
122- constant T* input = (constant T*)((constant uint8_t *)input_ + offsets[tid].y );
123- constant T* other = (constant T*)((constant uint8_t *)other_ + offsets[tid].z );
124-
125- *out = copysign (static_cast <float >(*input), static_cast <float >(*other));
126- }
127-
128122#define REGISTER_BINARY_OP (NAME, DTYPE ) \
129123 template [[host_name(#NAME " _" #DTYPE)]] kernel void NAME<DTYPE>( \
130124 constant void * input_, \
@@ -133,17 +127,14 @@ kernel void copysign_integral(
133127 constant uint3* offsets, \
134128 uint tid)
135129
136- #define REGISTER_COPYSIGN_INTEGRAL_OP (DTYPE ) \
137- template [[host_name(" copysign_" #DTYPE)]] kernel void \
138- copysign_integral<DTYPE>( \
139- constant void * input_ [[buffer(0 )]], \
140- constant void * other_ [[buffer(1 )]], \
141- device void * out_ [[buffer(2 )]], \
142- constant uint3* offsets [[buffer(3 )]], \
143- uint tid [[thread_position_in_grid]]);
144-
// copysign registrations. The long, int, short, uchar, char and bool dtypes
// now go through REGISTER_BINARY_INDEXING_OP like float and half, in place of
// the REGISTER_COPYSIGN_INTEGRAL_OP invocations this change deletes.
// NOTE(review): the macro body is only partly visible here; it presumably
// instantiates the kernels with copysign_functor and takes the output type
// from result_of<DTYPE, copysign_functor> — confirm against the full macro.
130+ REGISTER_BINARY_INDEXING_OP (copysign, long );
131+ REGISTER_BINARY_INDEXING_OP (copysign, int );
145132REGISTER_BINARY_INDEXING_OP (copysign, float );
146133REGISTER_BINARY_INDEXING_OP (copysign, half);
134+ REGISTER_BINARY_INDEXING_OP (copysign, short );
135+ REGISTER_BINARY_INDEXING_OP (copysign, uchar);
136+ REGISTER_BINARY_INDEXING_OP (copysign, char );
137+ REGISTER_BINARY_INDEXING_OP (copysign, bool );
147138REGISTER_BINARY_INDEXING_OP (fmax, float );
148139REGISTER_BINARY_INDEXING_OP (fmax, half);
149140REGISTER_BINARY_INDEXING_OP (fmin, float );
@@ -160,12 +151,6 @@ REGISTER_BINARY_INDEXING_OP(fmin, bfloat);
160151REGISTER_BINARY_INDEXING_OP (nextafter, bfloat);
161152REGISTER_BINARY_INDEXING_OP (zeta, bfloat);
162153#endif
163- REGISTER_COPYSIGN_INTEGRAL_OP (int );
164- REGISTER_COPYSIGN_INTEGRAL_OP (long );
165- REGISTER_COPYSIGN_INTEGRAL_OP (short );
166- REGISTER_COPYSIGN_INTEGRAL_OP (char );
167- REGISTER_COPYSIGN_INTEGRAL_OP (uchar);
168- REGISTER_COPYSIGN_INTEGRAL_OP (bool );
169154
170155// Complex binary functions
171156template <typename T>
0 commit comments