Skip to content

Commit e2a5c42

Browse files
malfet authored and pytorchmergebot committed
[BE][MPS] Build Metal kernels for macOS 14+ (pytorch#159733)
Building the Metal kernels for macOS 14+ only makes the `#if __METAL_VERSION__ >= 310` guards around `bfloat` use unnecessary. Rename `kernels_bfloat.metallib` to `kernels_basic.metallib` and remove the custom build/selection logic. Part of pytorch#159275. Pull Request resolved: pytorch#159733. Approved by: https://github.com/dcci. ghstack dependencies: pytorch#159731, pytorch#159732.
1 parent 5116c49 commit e2a5c42

31 files changed

+11
-538
lines changed

aten/src/ATen/CMakeLists.txt

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -704,21 +704,17 @@ if(USE_MPS)
704704
if(CAN_COMPILE_METAL)
705705
foreach(SHADER ${native_mps_metal})
706706
cmake_path(GET SHADER STEM TGT_STEM)
707-
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
708-
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
707+
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
709708
list(APPEND AIR_BASIC ${TGT_BASIC})
710-
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
711-
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
712-
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
709+
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
713710
endforeach()
714711
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
715-
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
716712
add_custom_command(
717713
COMMAND echo "// $$(date)" > metallib_dummy.cpp
718-
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
714+
DEPENDS kernels_basic.metallib
719715
OUTPUT metallib_dummy.cpp
720716
COMMENT "Updating metallibs timestamp")
721-
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
717+
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
722718
else()
723719
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
724720
foreach(SHADER ${native_mps_metal})

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -953,8 +953,7 @@ void executeMPSAllocatorCallback(void* ptr, EventType event) override {}
953953
if (C10_UNLIKELY(!library)) {
954954
auto device = MPSDevice::getInstance()->device();
955955
NSError* error = nil;
956-
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
957-
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
956+
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
958957
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
959958
}
960959
return library;

aten/src/ATen/native/mps/kernels/ActivationKernel.metal

Lines changed: 0 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -33,21 +33,15 @@ struct shrink_backward_functor {
3333

3434
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
3535
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
36-
#if __METAL_VERSION__ >= 310
3736
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
38-
#endif
3937

4038
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
4139
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
42-
#if __METAL_VERSION__ >= 310
4340
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
44-
#endif
4541

4642
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
4743
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
48-
#if __METAL_VERSION__ >= 310
4944
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
50-
#endif
5145

5246
struct hardsigmoid_functor {
5347
template <typename T>
@@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
6761

6862
REGISTER_UNARY_OP(hardsigmoid, float, float);
6963
REGISTER_UNARY_OP(hardsigmoid, half, half);
70-
#if __METAL_VERSION__ >= 310
7164
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
72-
#endif
7365

7466
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
7567
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
76-
#if __METAL_VERSION__ >= 310
7768
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
78-
#endif
7969

8070
struct hardswish_functor {
8171
template <typename T>
@@ -103,15 +93,11 @@ struct hardswish_backward_functor {
10393

10494
REGISTER_UNARY_OP(hardswish, float, float);
10595
REGISTER_UNARY_OP(hardswish, half, half);
106-
#if __METAL_VERSION__ >= 310
10796
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
108-
#endif
10997

11098
REGISTER_BINARY_OP(hardswish_backward, float, float);
11199
REGISTER_BINARY_OP(hardswish_backward, half, half);
112-
#if __METAL_VERSION__ >= 310
113100
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
114-
#endif
115101

116102
struct leaky_relu_functor {
117103
template <typename T>
@@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
135121

136122
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
137123
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
138-
#if __METAL_VERSION__ >= 310
139124
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
140-
#endif
141125

142126
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
143127
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
144-
#if __METAL_VERSION__ >= 310
145128
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
146-
#endif

aten/src/ATen/native/mps/kernels/Amp.metal

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -113,18 +113,12 @@ kernel void ampUpdateScale(
113113

114114
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
115115
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
116-
#if __METAL_VERSION__ >= 310
117116
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
118-
#endif
119117

120118
INSTANTIATE_AMP_UPDATE_SCALE(float);
121119
INSTANTIATE_AMP_UPDATE_SCALE(half);
122-
#if __METAL_VERSION__ >= 310
123120
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
124-
#endif
125121

126122
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
127123
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
128-
#if __METAL_VERSION__ >= 310
129124
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
130-
#endif

aten/src/ATen/native/mps/kernels/Attention.metal

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -590,9 +590,7 @@ kernel void attention(
590590

591591
INSTANTIATE_SDPA_VECTOR_HEADS(float);
592592
INSTANTIATE_SDPA_VECTOR_HEADS(half);
593-
#if __METAL_VERSION__ >= 310
594593
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
595-
#endif
596594

597595
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
598596
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
@@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
621619

622620
INSTANTIATE_ATTN_SHAPES_HELPER(float);
623621
INSTANTIATE_ATTN_SHAPES_HELPER(half);
624-
#if __METAL_VERSION__ >= 310
625622
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
626-
#endif

aten/src/ATen/native/mps/kernels/BinaryKernel.metal

Lines changed: 2 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
209209
};
210210

211211
struct nextafter_functor {
212-
#if __METAL_VERSION__ < 310
213-
template <typename U>
214-
struct bit_type {};
215-
template <>
216-
struct bit_type<float> {
217-
using type = int;
218-
};
219-
template <>
220-
struct bit_type<half> {
221-
using type = short;
222-
};
223-
#endif
224212
template <typename T>
225213
inline T operator()(const T a, const T b) {
226-
#if __METAL_VERSION__ >= 310
227214
return static_cast<T>(::metal::nextafter(a, b));
228-
#else
229-
using U = typename bit_type<T>::type;
230-
if (a == b) {
231-
return a;
232-
}
233-
if (::metal::isunordered(a, b)) {
234-
return NAN;
235-
}
236-
if (a == 0) {
237-
constexpr auto eps = as_type<T>(static_cast<U>(1));
238-
return b > 0 ? eps : -eps;
239-
}
240-
auto bits = as_type<U>(a);
241-
(a > 0) ^ (a > b) ? bits++ : bits--;
242-
return as_type<T>(bits);
243-
#endif
244215
}
245216
};
246217

@@ -344,13 +315,6 @@ struct fmod_functor {
344315
}
345316
};
346317

347-
// Some helper defines
348-
#if __METAL_VERSION__ >= 310
349-
#define _METAL_310_PLUS(x) x
350-
#else
351-
#define _METAL_310_PLUS(x)
352-
#endif
353-
354318
#define REGISTER_INTEGER_BINARY_OP(NAME) \
355319
REGISTER_BINARY_OP(NAME, long, long); \
356320
REGISTER_BINARY_OP(NAME, int, int); \
@@ -370,12 +334,12 @@ struct fmod_functor {
370334
#define REGISTER_FLOAT_BINARY_OP(NAME) \
371335
REGISTER_BINARY_OP(NAME, float, float); \
372336
REGISTER_BINARY_OP(NAME, half, half); \
373-
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
337+
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
374338

375339
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
376340
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
377341
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
378-
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
342+
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
379343

380344
REGISTER_FLOAT_BINARY_OP(copysign);
381345
REGISTER_INT2FLOAT_BINARY_OP(copysign);
@@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
447411
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
448412
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
449413

450-
#if __METAL_VERSION__ >= 310
451414
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
452415
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
453416
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
454-
#endif
455417

456418
// Complex binary functions
457419
REGISTER_BINARY_OP(polar, float, float2);

aten/src/ATen/native/mps/kernels/Bucketization.metal

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
180180
REGISTER_SEARCHSORTED_OP(float, long);
181181
REGISTER_SEARCHSORTED_OP(half, int);
182182
REGISTER_SEARCHSORTED_OP(half, long);
183-
#if __METAL_VERSION__ >= 310
184183
REGISTER_SEARCHSORTED_OP(bfloat, int);
185184
REGISTER_SEARCHSORTED_OP(bfloat, long);
186-
#endif
187185
REGISTER_SEARCHSORTED_OP(char, int);
188186
REGISTER_SEARCHSORTED_OP(char, long);
189187
REGISTER_SEARCHSORTED_OP(uchar, int);

aten/src/ATen/native/mps/kernels/Col2Im.metal

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -96,6 +96,4 @@ kernel void col2im_kernel(
9696
INSTANTIATE_COL2IM(bool);
9797
INSTANTIATE_COL2IM(float);
9898
INSTANTIATE_COL2IM(half);
99-
#if __METAL_VERSION__ >= 310
10099
INSTANTIATE_COL2IM(bfloat);
101-
#endif

aten/src/ATen/native/mps/kernels/CrossKernel.metal

Lines changed: 0 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
2020
REGISTER_CROSS_FUNC(char);
2121
REGISTER_CROSS_FUNC(uchar);
2222
REGISTER_CROSS_FUNC(bool);
23-
#if __METAL_VERSION__ >= 310
2423
REGISTER_CROSS_FUNC(bfloat);
25-
#endif
2624

2725
template <typename T, typename U>
2826
kernel void cross(
@@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
6866
REGISTER_CROSS_OP(char);
6967
REGISTER_CROSS_OP(uchar);
7068
REGISTER_CROSS_OP(bool);
71-
#if __METAL_VERSION__ >= 310
7269
REGISTER_CROSS_OP(bfloat);
73-
#endif

aten/src/ATen/native/mps/kernels/FusedOptimizerOps.metal

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,9 @@
11
#include <metal_stdlib>
22

33
using metal::max;
4-
#if __METAL_VERSION__ >= 310
54
bfloat max(bfloat a, bfloat b) {
65
return a > b ? a : b;
76
}
8-
#endif
97

108
#define kmaxThreadGroups 32
119
#define kmaxTensors 32
@@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
306304
REGISTER_ADAM_OPS_QUART(float, half);
307305
REGISTER_ADAM_OPS_QUART(half, float);
308306
REGISTER_ADAM_OPS_QUART(half, half);
309-
#if __METAL_VERSION__ >= 310
310307
REGISTER_ADAM_OPS_QUART(float, bfloat);
311308
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
312309
REGISTER_ADAM_OPS_QUART(bfloat, float);
313-
#endif
314310

315311
template <typename T>
316312
inline void sgd_momentum_math(
@@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
460456
REGISTER_FUSED_SGD_OP(half);
461457
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
462458
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
463-
#if __METAL_VERSION__ >= 310
464459
REGISTER_FUSED_SGD_OP(bfloat);
465460
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
466-
#endif

0 commit comments

Comments
 (0)