Skip to content

Commit 2693efc

Browse files
committed
fix segment coo indexing
1 parent 4c4a2e6 commit 2693efc

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

csrc/cuda/segment_coo_cuda.cu

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ __global__ void segment_coo_broadcast_kernel(
8585

8686
int D = index_info.sizes[index_info.dims - 1];
8787
int E_1 = E / D;
88-
int E_2 = D + TB - (D % TB);
88+
int E_2 = (D - 1) + TB - ((D - 1) % TB);
8989

9090
int row_idx = blockIdx.x * blockDim.y + threadIdx.y;
9191
int col_idx = blockIdx.y * blockDim.x + threadIdx.x;
@@ -215,6 +215,12 @@ segment_coo_cuda(torch::Tensor src, torch::Tensor index,
215215
auto N = out.size(dim);
216216
auto avg_len = (float)E_2 / (float)N;
217217

218+
std::cout << "E " << E << std::endl;
219+
std::cout << "E2 " << E_2 << std::endl;
220+
std::cout << "E1 " << E_1 << std::endl;
221+
std::cout << "K " << K << std::endl;
222+
std::cout << "N " << N << std::endl;
223+
218224
auto index_info = at::cuda::detail::getTensorInfo<int64_t, int>(index);
219225
auto stream = at::cuda::getCurrentCUDAStream();
220226
AT_DISPATCH_ALL_TYPES(src.scalar_type(), "segment_coo_kernel", [&] {

0 commit comments

Comments
 (0)