Skip to content

Commit 05a3fab

Browse files
authored
Merge pull request #11 from isl-org/yuanxion/fix-input-numel
I'm sorry I did not reply you so long, but I have marged your PR!
2 parents 72b9129 + bc36b08 commit 05a3fab

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

cxx/mcubes_cuda.cu

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -331,13 +331,13 @@ __device__ float3 vertexInterp(float isolevel, float3 p1, float3 p2, float valp1
331331
}
332332

333333
__global__ void mcubes_cuda_kernel(
334-
const torch::PackedTensorAccessor32<float, 3, torch::RestrictPtrTraits> vol,
335-
torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> vertices,
336-
torch::PackedTensorAccessor32<int, 3, torch::RestrictPtrTraits> ntris_in_cells,
334+
const torch::PackedTensorAccessor64<float, 3, torch::RestrictPtrTraits> vol,
335+
torch::PackedTensorAccessor64<float, 5, torch::RestrictPtrTraits> vertices,
336+
torch::PackedTensorAccessor64<int, 3, torch::RestrictPtrTraits> ntris_in_cells,
337337
int3 nGrids,
338338
float threshold,
339-
const torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits> edgeTable,
340-
const torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> triTable) {
339+
const torch::PackedTensorAccessor64<int, 1, torch::RestrictPtrTraits> edgeTable,
340+
const torch::PackedTensorAccessor64<int, 2, torch::RestrictPtrTraits> triTable) {
341341

342342
const int ix = blockIdx.x * blockDim.x + threadIdx.x;
343343
const int iy = blockIdx.y * blockDim.y + threadIdx.y;
@@ -436,12 +436,12 @@ __global__ void mcubes_cuda_kernel(
436436
}
437437

438438
__global__ void compaction(
439-
const torch::PackedTensorAccessor32<float, 5, torch::RestrictPtrTraits> vertBuf,
440-
const torch::PackedTensorAccessor32<int, 3, torch::RestrictPtrTraits> ntris,
441-
const torch::PackedTensorAccessor32<int, 3, torch::RestrictPtrTraits> offsets,
439+
const torch::PackedTensorAccessor64<float, 5, torch::RestrictPtrTraits> vertBuf,
440+
const torch::PackedTensorAccessor64<int, 3, torch::RestrictPtrTraits> ntris,
441+
const torch::PackedTensorAccessor64<int, 3, torch::RestrictPtrTraits> offsets,
442442
int3 nGrids,
443-
torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> verts,
444-
torch::PackedTensorAccessor32<int, 2, torch::RestrictPtrTraits> faces) {
443+
torch::PackedTensorAccessor64<float, 2, torch::RestrictPtrTraits> verts,
444+
torch::PackedTensorAccessor64<int, 2, torch::RestrictPtrTraits> faces) {
445445

446446
const int ix = blockIdx.x * blockDim.x + threadIdx.x;
447447
const int iy = blockIdx.y * blockDim.y + threadIdx.y;
@@ -521,13 +521,13 @@ std::vector<torch::Tensor> mcubes_cuda(torch::Tensor vol, float threshold) {
521521
// Kernel call
522522
cudaSetDevice(deviceId);
523523
mcubes_cuda_kernel<<<blocks, threads, 0, stream>>>(
524-
vol.packed_accessor32<float, 3, torch::RestrictPtrTraits>(),
525-
vert_buffer.packed_accessor32<float, 5, torch::RestrictPtrTraits>(),
526-
ntris_in_cells.packed_accessor32<int, 3, torch::RestrictPtrTraits>(),
524+
vol.packed_accessor64<float, 3, torch::RestrictPtrTraits>(),
525+
vert_buffer.packed_accessor64<float, 5, torch::RestrictPtrTraits>(),
526+
ntris_in_cells.packed_accessor64<int, 3, torch::RestrictPtrTraits>(),
527527
nGrids,
528528
threshold,
529-
edgeTableTensorCuda.packed_accessor32<int, 1, torch::RestrictPtrTraits>(),
530-
triTableTensorCuda.packed_accessor32<int, 2, torch::RestrictPtrTraits>()
529+
edgeTableTensorCuda.packed_accessor64<int, 1, torch::RestrictPtrTraits>(),
530+
triTableTensorCuda.packed_accessor64<int, 2, torch::RestrictPtrTraits>()
531531
);
532532
cudaDeviceSynchronize();
533533

@@ -549,12 +549,12 @@ std::vector<torch::Tensor> mcubes_cuda(torch::Tensor vol, float threshold) {
549549

550550
cudaSetDevice(deviceId);
551551
compaction<<<blocks, threads, 0, stream>>>(
552-
vert_buffer.packed_accessor32<float, 5, torch::RestrictPtrTraits>(),
553-
ntris_in_cells.packed_accessor32<int, 3, torch::RestrictPtrTraits>(),
554-
offsets.packed_accessor32<int, 3, torch::RestrictPtrTraits>(),
552+
vert_buffer.packed_accessor64<float, 5, torch::RestrictPtrTraits>(),
553+
ntris_in_cells.packed_accessor64<int, 3, torch::RestrictPtrTraits>(),
554+
offsets.packed_accessor64<int, 3, torch::RestrictPtrTraits>(),
555555
nGrids,
556-
verts.packed_accessor32<float, 2, torch::RestrictPtrTraits>(),
557-
faces.packed_accessor32<int, 2, torch::RestrictPtrTraits>()
556+
verts.packed_accessor64<float, 2, torch::RestrictPtrTraits>(),
557+
faces.packed_accessor64<int, 2, torch::RestrictPtrTraits>()
558558
);
559559
cudaDeviceSynchronize();
560560

0 commit comments

Comments
 (0)