Skip to content

Commit 2ef51cf

Browse files
malfet and pytorchmergebot
authored and committed
[BE][MPS] Infer results of functor (pytorch#147182)
Do not assume that a functor returns the same type as its arguments; instead, infer the result type automatically using `decltype` and `::metal::declval`. This is a no-op that prepares for the migration of `copysign` for integral arguments, which would return a float. Pull Request resolved: pytorch#147182. Approved by: https://github.com/dcci
1 parent 331d5cf commit 2ef51cf

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

aten/src/ATen/native/mps/kernels/BinaryKernel.metal

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -66,16 +66,22 @@ struct nextafter_functor {
6666
}
6767
};
6868

69+
// Future BinaryTensorIterator
70+
template <typename T, typename F>
71+
using result_of = decltype(::metal::declval<F>()(
72+
::metal::declval<T>(),
73+
::metal::declval<T>()));
74+
6975
template <typename T, typename F>
7076
kernel void binary_indexing(
7177
constant void* input_ [[buffer(0)]],
7278
constant void* other_ [[buffer(1)]],
7379
device void* out_ [[buffer(2)]],
7480
constant uint3* offsets [[buffer(3)]],
7581
uint tid [[thread_position_in_grid]]) {
76-
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
77-
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
78-
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
82+
auto out = (device result_of<T, F>*)((device uint8_t*)out_ + offsets[tid].x);
83+
auto input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
84+
auto other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
7985
F f;
8086
*out = f(*input, *other);
8187
}
@@ -84,7 +90,7 @@ template <typename T, typename F>
8490
kernel void binary_dense(
8591
constant T* input [[buffer(0)]],
8692
constant T* other [[buffer(1)]],
87-
device T* out [[buffer(2)]],
93+
device result_of<T, F>* out [[buffer(2)]],
8894
uint tid [[thread_position_in_grid]]) {
8995
F f;
9096
out[tid] = f(input[tid], other[tid]);
@@ -102,7 +108,7 @@ kernel void binary_dense(
102108
binary_dense<DTYPE, NAME##_functor>( \
103109
constant DTYPE * input_, \
104110
constant DTYPE * other_, \
105-
device DTYPE * out_, \
111+
device result_of<DTYPE, NAME##_functor> * out_, \
106112
uint tid)
107113

108114
template <typename T>

0 commit comments

Comments
 (0)