
Commit 24a18d7

malfet authored and pytorchmergebot committed
[MPS] Use metal shaders for all view ops (pytorch#143375)
Before this PR, Metal shaders were used to scatter/gather only 1-5 dimensional tensors. This PR introduces generalized kernels that can be used for any dimensionality and, as a result, gets rid of 700+ lines of complex and untested code that might not even work as expected. The generalized gather shader looks as follows:

```metal
kernel void gather_kernel_n(uint linear_index          [[thread_position_in_grid]],
                            constant void * src_       [[buffer(0)]],
                            device void * dst_         [[buffer(1)]],
                            constant uint32_t * size   [[buffer(2)]],
                            constant uint32_t * stride [[buffer(3)]],
                            constant uint32_t & numel  [[buffer(4)]],
                            constant int32_t & ndim    [[buffer(5)]]) {{
    if (linear_index >= numel) return;

    constant {0} * src = (constant {0} *)src_;
    device {1} * dst = (device {1} *)dst_;

    uint64_t src_offs = 0;
    auto src_idx = linear_index;
    for(int dim = ndim - 1; dim >= 0; --dim) {{
      src_offs += stride[dim] * (src_idx % size[dim]);
      src_idx /= size[dim];
    }}

    dst[linear_index] = cast<{1}>(src[src_offs]);
}}
```

which, according to the following benchmark

```python
from timeit import default_timer

import torch
import torch.utils.cpp_extension
from torch.utils.benchmark import Measurement, Timer

t = Timer(
    stmt=f"y.copy_(x);torch.mps.synchronize()",
    setup=f"x=torch.rand(4, 5, 16, 64, 33, 24, dtype=torch.float32, device='mps')[:,:,:,:24,:24,];y=torch.empty(x.shape, device=x.device, dtype=x.dtype)",
    language="python",
    timer=default_timer,
)
print(t.blocked_autorange())
```

is almost twice as fast as the previous implementation (i.e. on a MacBook M2 Pro it returns 2.9ms for the MPS version vs 1.5ms for the shader one).

On macOS Sequoia, [`gatherWithUpdatesTensor:indicesTensor:...`](https://developer.apple.com/documentation/metalperformanceshadersgraph/mpsgraph/gather(withupdatestensor:indicestensor:axis:batchdimensions:name:)?language=objc) crashes if invoked with a complex data type, as one can see by running the code below:

```swift
import Metal
import MetalPerformanceShadersGraph

func gatherComplexMPS(device: MTLDevice,
                      inp_buf: MTLBuffer,
                      idx_buf: MTLBuffer,
                      out_buf: MTLBuffer,
                      inp_elem: Int,
                      upd_elem: Int) {
  let graph = MPSGraph()
  let inputPlaceholder = graph.placeholder(shape: [inp_elem as NSNumber], dataType: .complexFloat32, name: nil)
  let indicesPlaceholder = graph.placeholder(shape: [upd_elem as NSNumber], dataType: .int64, name: nil)
  let outNode = graph.gather(withUpdatesTensor: inputPlaceholder, indicesTensor: indicesPlaceholder, axis: 0, batchDimensions: 0, name: nil)
  let mpsInputBuffer = MPSGraphTensorData(inp_buf, shape: [inp_elem as NSNumber], dataType: .complexFloat32)
  let mpsIndicesBuffer = MPSGraphTensorData(idx_buf, shape: [upd_elem as NSNumber], dataType: .int64)
  let mpsOutputBuffer = MPSGraphTensorData(out_buf, shape: [inp_elem as NSNumber], dataType: .complexFloat32)
  guard let queue = device.makeCommandQueue() else { fatalError("Can't make queue") }
  graph.run(with: queue,
            feeds: [inputPlaceholder: mpsInputBuffer, indicesPlaceholder: mpsIndicesBuffer],
            targetOperations: nil,
            resultsDictionary: [outNode: mpsOutputBuffer])
}

func makeBufferWithValues<T>(device: MTLDevice, values: [T]) -> MTLBuffer {
  guard let buf = device.makeBuffer(length: values.count * MemoryLayout<T>.size, options: [.storageModeShared]) else {
    fatalError("Can't alloc")
  }
  let buf_data = buf.contents().assumingMemoryBound(to: T.self)
  for i in 0..<values.count {
    buf_data[i] = values[i]
  }
  return buf
}

guard let device = MTLCopyAllDevices().first else { fatalError("No Metal device found") }
print("Using device \(device.name)")
let inp_buf = makeBufferWithValues(device: device, values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
let idx_buf = makeBufferWithValues(device: device, values: [0, 1, 2, 3])
guard let out_buf = device.makeBuffer(length: 8 * MemoryLayout<Float>.size, options: [.storageModeShared]) else {
  fatalError("Can't alloc")
}
gatherComplexMPS(device: device, inp_buf: inp_buf, idx_buf: idx_buf, out_buf: out_buf, inp_elem: 4, upd_elem: 4)
```

Fixes pytorch#143140

Pull Request resolved: pytorch#143375
Approved by: https://github.com/albanD
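For illustration only (not part of the commit message), the index arithmetic in `gather_kernel_n` can be modeled on the host side. This minimal Python sketch decomposes a linear index over the view's sizes and accumulates the strided source offset, mirroring the shader's loop; the sizes and strides in the example are made up.

```python
# Host-side model of the gather_kernel_n loop: walk dimensions from innermost
# to outermost, peel off each coordinate with % and //, and accumulate the
# strided offset into the (possibly non-contiguous) source tensor.
def gather_offset(linear_index, size, stride):
    src_offs = 0
    src_idx = linear_index
    for dim in range(len(size) - 1, -1, -1):
        src_offs += stride[dim] * (src_idx % size[dim])
        src_idx //= size[dim]
    return src_offs

# Example: a 2x3 view whose parent rows hold 4 elements (element strides (4, 1)).
print([gather_offset(i, (2, 3), (4, 1)) for i in range(6)])  # [0, 1, 2, 4, 5, 6]
```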
1 parent f47aac6 · commit 24a18d7

File tree

4 files changed: +64 −757 lines changed


aten/src/ATen/mps/IndexKernels.h

Lines changed: 29 additions & 50 deletions
```diff
@@ -3,10 +3,6 @@
 namespace at::mps {
 
 static const char* SCATTER_OPS_TEMPLATE = R"METAL_SCATTER(
-struct __attribute__ ((packed)) packed_uint5{{
-  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
-}};
-
 template<typename Y, typename X>
 Y cast(const X x);
 
@@ -15,32 +11,26 @@ template<>
   return {2};
 }}
 
-kernel void scatter_kernel_5(uint linear_index              [[thread_position_in_grid]],
-                             constant void * src_           [[buffer(0)]],
-                             device void * dst_             [[buffer(1)]],
-                             constant packed_uint5 & size   [[buffer(2)]],
-                             constant packed_uint5 & stride [[buffer(3)]],
-                             constant uint32_t & numel      [[buffer(4)]]) {{
+kernel void scatter_kernel_n(uint linear_index          [[thread_position_in_grid]],
+                             constant void * src_       [[buffer(0)]],
+                             device void * dst_         [[buffer(1)]],
+                             constant uint32_t * size   [[buffer(2)]],
+                             constant uint32_t * stride [[buffer(3)]],
+                             constant uint32_t & numel  [[buffer(4)]],
+                             constant int32_t & ndim    [[buffer(5)]]) {{
     if (linear_index >= numel) return;
 
     constant {0} * src = (constant {0} *)src_;
     device {1} * dst = (device {1} *)dst_;
 
-    packed_uint5 local_index;
-    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
-    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
-    local_index.z = linear_index / (size.u * size.w) % size.z;
-    local_index.w = linear_index / size.u % size.w;
-    local_index.u = linear_index % size.u;
-
-    packed_uint5 strided_index;
-    strided_index.x = local_index.x * stride.x;
-    strided_index.y = local_index.y * stride.y;
-    strided_index.z = local_index.z * stride.z;
-    strided_index.w = local_index.w * stride.w;
-    strided_index.u = local_index.u * stride.u;
-
-    dst[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u] = cast<{1}>(src[linear_index]);
+    uint64_t dst_offs = 0;
+    auto dst_idx = linear_index;
+    for(int dim = ndim - 1; dim >= 0; --dim) {{
+      dst_offs += stride[dim] * (dst_idx % size[dim]);
+      dst_idx /= size[dim];
+    }}
+
+    dst[dst_offs] = cast<{1}>(src[linear_index]);
 }}
 
 kernel void scatter_kernel_4(uint linear_index [[thread_position_in_grid]],
@@ -121,10 +111,6 @@ kernel void scatter_kernel_1(uint linear_index [[thread_position_in
 )METAL_SCATTER";
 
 static const char* GATHER_OPS_TEMPLATE = R"METAL_GATHER(
-struct __attribute__ ((packed)) packed_uint5{{
-  uint32_t x; uint32_t y; uint32_t z; uint32_t w; uint32_t u;
-}};
-
 template<typename Y, typename X>
 Y cast(const X x);
 
@@ -133,33 +119,26 @@ template<>
   return {2};
 }}
 
-kernel void gather_kernel_5(uint linear_index               [[thread_position_in_grid]],
-                            constant void * src_            [[buffer(0)]],
-                            device void * dst_              [[buffer(1)]],
-                            constant packed_uint5 & size    [[buffer(2)]],
-                            constant packed_uint5 & stride  [[buffer(3)]],
-                            constant uint32_t & numel       [[buffer(4)]]) {{
+kernel void gather_kernel_n(uint linear_index          [[thread_position_in_grid]],
+                            constant void * src_       [[buffer(0)]],
+                            device void * dst_         [[buffer(1)]],
+                            constant uint32_t * size   [[buffer(2)]],
+                            constant uint32_t * stride [[buffer(3)]],
+                            constant uint32_t & numel  [[buffer(4)]],
+                            constant int32_t & ndim    [[buffer(5)]]) {{
     if (linear_index >= numel) return;
 
     constant {0} * src = (constant {0} *)src_;
     device {1} * dst = (device {1} *)dst_;
 
+    uint64_t src_offs = 0;
+    auto src_idx = linear_index;
+    for(int dim = ndim - 1; dim >= 0; --dim) {{
+      src_offs += stride[dim] * (src_idx % size[dim]);
+      src_idx /= size[dim];
+    }}
 
-    packed_uint5 local_index;
-    local_index.x = linear_index / (size.u * size.w * size.z * size.y) % size.x;
-    local_index.y = linear_index / (size.u * size.w * size.z) % size.y;
-    local_index.z = linear_index / (size.u * size.w) % size.z;
-    local_index.w = linear_index / size.u % size.w;
-    local_index.u = linear_index % size.u;
-
-    packed_uint5 strided_index;
-    strided_index.x = local_index.x * stride.x;
-    strided_index.y = local_index.y * stride.y;
-    strided_index.z = local_index.z * stride.z;
-    strided_index.w = local_index.w * stride.w;
-    strided_index.u = local_index.u * stride.u;
-
-    dst[linear_index] = cast<{1}>(src[strided_index.x + strided_index.y + strided_index.z + strided_index.w + strided_index.u]);
+    dst[linear_index] = cast<{1}>(src[src_offs]);
 }}
 
 kernel void gather_kernel_4(uint linear_index [[thread_position_in_grid]],
```
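As a sanity check (illustration only, not part of the diff), the generic per-dimension loop introduced here can be compared against the removed fixed 5-D `packed_uint5` arithmetic. The Python sketch below asserts both compute identical offsets; the example sizes and strides (a 5-D slice of a contiguous parent) are assumptions chosen for the demonstration.

```python
def offset_old_5d(linear_index, size, stride):
    # Removed scatter/gather_kernel_5 arithmetic: explicit 5-D decomposition.
    x, y, z, w, u = size
    sx, sy, sz, sw, su = stride
    return ((linear_index // (u * w * z * y) % x) * sx +
            (linear_index // (u * w * z) % y) * sy +
            (linear_index // (u * w) % z) * sz +
            (linear_index // u % w) * sw +
            (linear_index % u) * su)

def offset_new(linear_index, size, stride):
    # New *_kernel_n arithmetic: generic loop over ndim dimensions.
    offs, idx = 0, linear_index
    for dim in range(len(size) - 1, -1, -1):
        offs += stride[dim] * (idx % size[dim])
        idx //= size[dim]
    return offs

# A (2, 3, 4, 5, 6) view into a contiguous (2, 3, 4, 6, 8) parent (element strides).
size, stride = (2, 3, 4, 5, 6), (576, 192, 48, 8, 1)
numel = 2 * 3 * 4 * 5 * 6
assert all(offset_old_5d(i, size, stride) == offset_new(i, size, stride) for i in range(numel))
```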

aten/src/ATen/native/mps/OperationUtils.h

Lines changed: 0 additions & 4 deletions
```diff
@@ -81,10 +81,6 @@ std::string getArrayRefString(const IntArrayRef s);
 // use has_storage() on the returned tensor to determine if src actually is a view
 Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
 Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
-bool canSliceViewTensor(const TensorBase& src, MPSShape* mpsShape);
-MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src,
-                                                 MPSShape* mpsShape,
-                                                 const MPSDataType mpsDataType);
 MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
                                MPSGraphTensor* inputTensor,
                                const TensorBase& input,
```
