Skip to content

Commit ed2e92b

Browse files
malfet authored and pytorchmergebot committed
[BE][MPS] Don't pass nnz to mark_segments (pytorch#170403)
Fixes the following unused-parameter warning:
```
/Users/malfet/git/pytorch/pytorch/aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal:288:27: warning: unused parameter 'nnz' [-Wunused-parameter]
    constant uint& nnz [[buffer(2)]],
```
Also, rely on the language's short-circuit evaluation rule to make the kernel more compact.

Pull Request resolved: pytorch#170403
Approved by: https://github.com/Skylion007
1 parent 494bce3 commit ed2e92b

File tree

2 files changed

+4
-9
lines changed

2 files changed

+4
-9
lines changed

aten/src/ATen/native/sparse/mps/SparseMPSTensorMath.mm

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,7 @@ static Tensor softmax_sparse_mps_impl(
14091409
auto pso = lib.getPipelineStateForFunc("mark_segments");
14101410
auto enc = stream->commandEncoder();
14111411
[enc setComputePipelineState:pso];
1412-
mtl_setArgs(enc, sorted_pool_indices, mask, nnz_u);
1412+
mtl_setArgs(enc, sorted_pool_indices, mask);
14131413

14141414
auto gridSize = MTLSizeMake(nnz, 1, 1);
14151415
auto threadGroupSize = MTLSizeMake(std::min<uint64_t>(nnz, pso.maxTotalThreadsPerThreadgroup), 1, 1);
@@ -1522,7 +1522,7 @@ static Tensor softmax_backward_sparse_mps_impl(
15221522
auto pso = lib.getPipelineStateForFunc("mark_segments");
15231523
auto enc = stream->commandEncoder();
15241524
[enc setComputePipelineState:pso];
1525-
mtl_setArgs(enc, sorted_pool_indices, mask, nnz_u);
1525+
mtl_setArgs(enc, sorted_pool_indices, mask);
15261526
auto gridSize = MTLSizeMake(nnz, 1, 1);
15271527
auto threadGroupSize = MTLSizeMake(std::min<uint64_t>(nnz, pso.maxTotalThreadsPerThreadgroup), 1, 1);
15281528
[enc dispatchThreads:gridSize threadsPerThreadgroup:threadGroupSize];
@@ -1592,4 +1592,4 @@ Tensor log_softmax_backward_sparse_mps(const Tensor& grad, const Tensor& output,
15921592

15931593
REGISTER_MPS_DISPATCH(sparse_mask_intersection_out_stub, &sparse_mask_intersection_out_mps_kernel);
15941594
REGISTER_MPS_DISPATCH(sparse_mask_projection_out_stub, &sparse_mask_projection_out_mps_kernel);
1595-
} // namespace at::native
1595+
} // namespace at::native

aten/src/ATen/native/sparse/mps/kernels/SparseTensorMath.metal

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,14 +285,9 @@ kernel void spmm_addmm_coo(
285285
kernel void mark_segments(
286286
device const int64_t* indices [[buffer(0)]],
287287
device int* mask [[buffer(1)]],
288-
constant uint& nnz [[buffer(2)]],
289288
uint tid [[thread_position_in_grid]])
290289
{
291-
if (tid == 0) {
292-
mask[0] = 1;
293-
} else {
294-
mask[tid] = (indices[tid] != indices[tid - 1]) ? 1 : 0;
295-
}
290+
mask[tid] = (tid == 0 || indices[tid] != indices[tid - 1]) ? 1 : 0;
296291
}
297292

298293
kernel void compute_offsets_and_counts(

0 commit comments

Comments
 (0)