Skip to content

Commit 278ffd8

Browse files
malfetpytorchmergebot
authored andcommitted
[MPS][BE] Add copysign integral flavors as functor (pytorch#147183)
Pull Request resolved: pytorch#147183 Approved by: https://github.com/dcci ghstack dependencies: pytorch#147182
1 parent 2ef51cf commit 278ffd8

File tree

2 files changed

+16
-31
lines changed

2 files changed

+16
-31
lines changed

aten/src/ATen/native/mps/kernels/BinaryKernel.metal

Lines changed: 15 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,17 @@ struct fmin_functor {
1818

1919
struct copysign_functor {
2020
template <typename T>
21-
inline T operator()(const T a, const T b) {
21+
inline enable_if_t<is_floating_point_v<T>, T> operator()(
22+
const T a,
23+
const T b) {
2224
return static_cast<T>(::metal::copysign(a, b));
2325
}
26+
template <typename T>
27+
inline enable_if_t<!is_floating_point_v<T>, float> operator()(
28+
const T a,
29+
const T b) {
30+
return ::metal::copysign(static_cast<float>(a), static_cast<float>(b));
31+
}
2432
};
2533

2634
struct zeta_functor {
@@ -111,20 +119,6 @@ kernel void binary_dense(
111119
device result_of<DTYPE, NAME##_functor> * out_, \
112120
uint tid)
113121

114-
template <typename T>
115-
kernel void copysign_integral(
116-
constant void* input_ [[buffer(0)]],
117-
constant void* other_ [[buffer(1)]],
118-
device void* out_ [[buffer(2)]],
119-
constant uint3* offsets [[buffer(3)]],
120-
uint tid [[thread_position_in_grid]]) {
121-
device float* out = (device float*)((device uint8_t*)out_ + offsets[tid].x);
122-
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
123-
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
124-
125-
*out = copysign(static_cast<float>(*input), static_cast<float>(*other));
126-
}
127-
128122
#define REGISTER_BINARY_OP(NAME, DTYPE) \
129123
template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
130124
constant void* input_, \
@@ -133,17 +127,14 @@ kernel void copysign_integral(
133127
constant uint3* offsets, \
134128
uint tid)
135129

136-
#define REGISTER_COPYSIGN_INTEGRAL_OP(DTYPE) \
137-
template [[host_name("copysign_" #DTYPE)]] kernel void \
138-
copysign_integral<DTYPE>( \
139-
constant void* input_ [[buffer(0)]], \
140-
constant void* other_ [[buffer(1)]], \
141-
device void* out_ [[buffer(2)]], \
142-
constant uint3* offsets [[buffer(3)]], \
143-
uint tid [[thread_position_in_grid]]);
144-
130+
REGISTER_BINARY_INDEXING_OP(copysign, long);
131+
REGISTER_BINARY_INDEXING_OP(copysign, int);
145132
REGISTER_BINARY_INDEXING_OP(copysign, float);
146133
REGISTER_BINARY_INDEXING_OP(copysign, half);
134+
REGISTER_BINARY_INDEXING_OP(copysign, short);
135+
REGISTER_BINARY_INDEXING_OP(copysign, uchar);
136+
REGISTER_BINARY_INDEXING_OP(copysign, char);
137+
REGISTER_BINARY_INDEXING_OP(copysign, bool);
147138
REGISTER_BINARY_INDEXING_OP(fmax, float);
148139
REGISTER_BINARY_INDEXING_OP(fmax, half);
149140
REGISTER_BINARY_INDEXING_OP(fmin, float);
@@ -160,12 +151,6 @@ REGISTER_BINARY_INDEXING_OP(fmin, bfloat);
160151
REGISTER_BINARY_INDEXING_OP(nextafter, bfloat);
161152
REGISTER_BINARY_INDEXING_OP(zeta, bfloat);
162153
#endif
163-
REGISTER_COPYSIGN_INTEGRAL_OP(int);
164-
REGISTER_COPYSIGN_INTEGRAL_OP(long);
165-
REGISTER_COPYSIGN_INTEGRAL_OP(short);
166-
REGISTER_COPYSIGN_INTEGRAL_OP(char);
167-
REGISTER_COPYSIGN_INTEGRAL_OP(uchar);
168-
REGISTER_COPYSIGN_INTEGRAL_OP(bool);
169154

170155
// Complex binary functions
171156
template <typename T>

aten/src/ATen/native/mps/operations/BinaryKernel.mm

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ static void fmin_mps_kernel(TensorIteratorBase& iter) {
110110
}
111111

112112
static void copysign_mps_kernel(TensorIteratorBase& iter) {
113-
mps::binary_mps_impl(iter, "copysign", false);
113+
mps::binary_mps_impl(iter, "copysign");
114114
}
115115

116116
static void nextafter_mps_kernel(TensorIteratorBase& iter) {

0 commit comments

Comments
 (0)