Skip to content

Commit dd41a39

Browse files
malfet authored and pytorchmergebot committed
[MPS] Fix unary/binary ops for 2**32+ elem tensors (pytorch#155183)
By using the `TensorIterator::with_32bit_indexing()` primitive.

Add a `bind_tensors` helper function that correctly sets up MPS tensors originating from TensorIterator.

TODO: Add comments to `bind_tensors` as well as a unit test, based on
```
python -c "import torch;print((torch.rand(1, 1024, 1024, dtype=torch.bfloat16, device='mps') + torch.rand(5000, 1, 1, dtype=torch.bfloat16, device='mps')).sin())"
```

Fixes pytorch#154828

Pull Request resolved: pytorch#155183
Approved by: https://github.com/cyyever, https://github.com/dcci, https://github.com/Skylion007
ghstack dependencies: pytorch#155150, pytorch#155178, pytorch#155184
1 parent 05dd638 commit dd41a39

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

aten/src/ATen/native/mps/MetalShaderLibrary.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ class MetalShaderLibrary {
154154
MTLLibrary_t lib,
155155
const std::string& fname);
156156
MTLLibrary_t compileLibrary(const std::string& src);
157+
void bind_tensors(MTLComputeCommandEncoder_t, TensorIteratorBase&);
157158
std::string shaderSource;
158159
unsigned nparams;
159160
MTLCompileOptions* compile_options;

aten/src/ATen/native/mps/OperationUtils.mm

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -971,10 +971,34 @@ static dispatch_data_t getSectionData(const std::string& name) {
971971
}
972972
};
973973

974+
void MetalShaderLibrary::bind_tensors(id<MTLComputeCommandEncoder> encoder, TensorIteratorBase& iter) {
975+
for (auto idx : c10::irange(iter.ntensors())) {
976+
auto& t = iter.tensor_base(idx);
977+
// Handle CPU scalars
978+
if (C10_UNLIKELY(t.device().type() == kCPU)) {
979+
mtl_setBuffer(encoder, t, idx);
980+
continue;
981+
}
982+
// At the moment, MPS storage data is not the real GPU pointer, but rather a pointer to an id<MTLBuffer> object
983+
// But TensorIterator constructs data_ptr as if base was just a raw pointer
984+
// Work around this problem by computing an offset from the start of the tensor, which works for both
985+
// tensor views and sliced 64-bit iterators
986+
auto offs = reinterpret_cast<size_t>(iter.data_ptr(idx)) - reinterpret_cast<size_t>(t.storage().data());
987+
[encoder setBuffer:getMTLBufferStorage(t) offset:offs atIndex:idx];
988+
}
989+
}
990+
974991
void MetalShaderLibrary::exec_unary_kernel(TensorIteratorBase& iter,
975992
const std::string& name,
976993
std::optional<int64_t> extra) {
977-
TORCH_CHECK(iter.can_use_32bit_indexing(), name, " can't be indexed using 32-bit iterator for shape ", iter.shape());
994+
// Decompose 64-bit tensor into 32-bit ones
995+
if (!iter.can_use_32bit_indexing()) {
996+
for (auto&& sub_iter : iter.with_32bit_indexing()) {
997+
exec_unary_kernel(sub_iter, name, extra);
998+
}
999+
return;
1000+
}
1001+
9781002
auto inputTensor = iter.input(0);
9791003
auto outputTensor = iter.output(0);
9801004
uint32_t length = iter.numel();
@@ -997,7 +1021,7 @@ static dispatch_data_t getSectionData(const std::string& name) {
9971021
getMPSProfiler().beginProfileKernel(cplState, name, {inputTensor});
9981022

9991023
[computeEncoder setComputePipelineState:cplState];
1000-
mtl_setArgs(computeEncoder, outputTensor, inputTensor);
1024+
bind_tensors(computeEncoder, iter);
10011025
if (!iter.is_contiguous()) {
10021026
mtl_setArgs<2>(computeEncoder,
10031027
outputTensor.sizes(),
@@ -1022,13 +1046,20 @@ static dispatch_data_t getSectionData(const std::string& name) {
10221046
// Right now running something like 1.0-torch.rand(5, device='mps') will create iterator with
10231047
// double as common dtype (because Python floating point are always 64-bit values)
10241048
TORCH_CHECK(iter.output().scalar_type() != at::kDouble, "float64 is not supported on MPS");
1025-
TORCH_CHECK(iter.can_use_32bit_indexing(), name, " can't be indexed using 32-bit iterator for shape ", iter.shape());
10261049

10271050
// Skip for empty iterators
10281051
if (iter.numel() == 0) {
10291052
return;
10301053
}
10311054

1055+
// Decompose 64-bit tensor into 32-bit ones
1056+
if (!iter.can_use_32bit_indexing()) {
1057+
for (auto&& sub_iter : iter.with_32bit_indexing()) {
1058+
exec_binary_kernel(sub_iter, name, alpha);
1059+
}
1060+
return;
1061+
}
1062+
10321063
auto convert_double_scalar = [](Tensor& t) {
10331064
if (t.dim() != 0) {
10341065
return;
@@ -1062,7 +1093,7 @@ static dispatch_data_t getSectionData(const std::string& name) {
10621093
getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
10631094
[computeEncoder setComputePipelineState:binaryPSO];
10641095
// Set input and output tensors
1065-
mtl_setArgs(computeEncoder, out, input, other);
1096+
bind_tensors(computeEncoder, iter);
10661097
// Iterator is contiguous if all of its elements are dense in storage,
10671098
// i.e. it's true for both row-first and column-first tensors
10681099
if (iter.is_contiguous()) {

test/test_mps.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7955,6 +7955,20 @@ def test_inplace_bitwise_not(self, dtype):
79557955
x[::2].bitwise_not_()
79567956
self.assertEqual(x_mps.cpu(), x_cpu)
79577957

7958+
7959+
class TestLargeTensors(TestCaseMPS):
7960+
def test_64bit_binops(self):
7961+
if torch.mps.recommended_max_memory() < 16_000_000_000:
7962+
raise unittest.SkipTest("Needs at least 16Gb of RAM")
7963+
a = torch.rand(1, 1024, 1024, dtype=torch.float16, device='mps')
7964+
b = torch.rand(5000, 1, 1, dtype=torch.float16, device='mps')
7965+
rc = (a + b).sin()
7966+
slice_idx = -2
7967+
rc_slice = rc[slice_idx:]
7968+
rc_slice_cpu = (a.cpu() + b.cpu()[slice_idx:]).sin()
7969+
self.assertEqual(rc_slice, rc_slice_cpu)
7970+
7971+
79587972
class TestLogical(TestCaseMPS):
79597973
def _wrap_tensor(self, x, device="cpu", dtype=None, requires_grad=False):
79607974
return torch.tensor(x, device=device, dtype=dtype, requires_grad=requires_grad)

0 commit comments

Comments
 (0)