@@ -331,13 +331,13 @@ __device__ float3 vertexInterp(float isolevel, float3 p1, float3 p2, float valp1
331331}
332332
333333__global__ void mcubes_cuda_kernel (
334- const torch::PackedTensorAccessor32 <float , 3 , torch::RestrictPtrTraits> vol,
335- torch::PackedTensorAccessor32 <float , 5 , torch::RestrictPtrTraits> vertices,
336- torch::PackedTensorAccessor32 <int , 3 , torch::RestrictPtrTraits> ntris_in_cells,
334+ const torch::PackedTensorAccessor64 <float , 3 , torch::RestrictPtrTraits> vol,
335+ torch::PackedTensorAccessor64 <float , 5 , torch::RestrictPtrTraits> vertices,
336+ torch::PackedTensorAccessor64 <int , 3 , torch::RestrictPtrTraits> ntris_in_cells,
337337 int3 nGrids,
338338 float threshold,
339- const torch::PackedTensorAccessor32 <int , 1 , torch::RestrictPtrTraits> edgeTable,
340- const torch::PackedTensorAccessor32 <int , 2 , torch::RestrictPtrTraits> triTable) {
339+ const torch::PackedTensorAccessor64 <int , 1 , torch::RestrictPtrTraits> edgeTable,
340+ const torch::PackedTensorAccessor64 <int , 2 , torch::RestrictPtrTraits> triTable) {
341341
342342 const int ix = blockIdx .x * blockDim .x + threadIdx .x ;
343343 const int iy = blockIdx .y * blockDim .y + threadIdx .y ;
@@ -436,12 +436,12 @@ __global__ void mcubes_cuda_kernel(
436436}
437437
438438__global__ void compaction (
439- const torch::PackedTensorAccessor32 <float , 5 , torch::RestrictPtrTraits> vertBuf,
440- const torch::PackedTensorAccessor32 <int , 3 , torch::RestrictPtrTraits> ntris,
441- const torch::PackedTensorAccessor32 <int , 3 , torch::RestrictPtrTraits> offsets,
439+ const torch::PackedTensorAccessor64 <float , 5 , torch::RestrictPtrTraits> vertBuf,
440+ const torch::PackedTensorAccessor64 <int , 3 , torch::RestrictPtrTraits> ntris,
441+ const torch::PackedTensorAccessor64 <int , 3 , torch::RestrictPtrTraits> offsets,
442442 int3 nGrids,
443- torch::PackedTensorAccessor32 <float , 2 , torch::RestrictPtrTraits> verts,
444- torch::PackedTensorAccessor32 <int , 2 , torch::RestrictPtrTraits> faces) {
443+ torch::PackedTensorAccessor64 <float , 2 , torch::RestrictPtrTraits> verts,
444+ torch::PackedTensorAccessor64 <int , 2 , torch::RestrictPtrTraits> faces) {
445445
446446 const int ix = blockIdx .x * blockDim .x + threadIdx .x ;
447447 const int iy = blockIdx .y * blockDim .y + threadIdx .y ;
@@ -521,13 +521,13 @@ std::vector<torch::Tensor> mcubes_cuda(torch::Tensor vol, float threshold) {
521521 // Kernel call
522522 cudaSetDevice (deviceId);
523523 mcubes_cuda_kernel<<<blocks, threads, 0 , stream>>> (
524- vol.packed_accessor32 <float , 3 , torch::RestrictPtrTraits>(),
525- vert_buffer.packed_accessor32 <float , 5 , torch::RestrictPtrTraits>(),
526- ntris_in_cells.packed_accessor32 <int , 3 , torch::RestrictPtrTraits>(),
524+ vol.packed_accessor64 <float , 3 , torch::RestrictPtrTraits>(),
525+ vert_buffer.packed_accessor64 <float , 5 , torch::RestrictPtrTraits>(),
526+ ntris_in_cells.packed_accessor64 <int , 3 , torch::RestrictPtrTraits>(),
527527 nGrids,
528528 threshold,
529- edgeTableTensorCuda.packed_accessor32 <int , 1 , torch::RestrictPtrTraits>(),
530- triTableTensorCuda.packed_accessor32 <int , 2 , torch::RestrictPtrTraits>()
529+ edgeTableTensorCuda.packed_accessor64 <int , 1 , torch::RestrictPtrTraits>(),
530+ triTableTensorCuda.packed_accessor64 <int , 2 , torch::RestrictPtrTraits>()
531531 );
532532 cudaDeviceSynchronize ();
533533
@@ -549,12 +549,12 @@ std::vector<torch::Tensor> mcubes_cuda(torch::Tensor vol, float threshold) {
549549
550550 cudaSetDevice (deviceId);
551551 compaction<<<blocks, threads, 0 , stream>>> (
552- vert_buffer.packed_accessor32 <float , 5 , torch::RestrictPtrTraits>(),
553- ntris_in_cells.packed_accessor32 <int , 3 , torch::RestrictPtrTraits>(),
554- offsets.packed_accessor32 <int , 3 , torch::RestrictPtrTraits>(),
552+ vert_buffer.packed_accessor64 <float , 5 , torch::RestrictPtrTraits>(),
553+ ntris_in_cells.packed_accessor64 <int , 3 , torch::RestrictPtrTraits>(),
554+ offsets.packed_accessor64 <int , 3 , torch::RestrictPtrTraits>(),
555555 nGrids,
556- verts.packed_accessor32 <float , 2 , torch::RestrictPtrTraits>(),
557- faces.packed_accessor32 <int , 2 , torch::RestrictPtrTraits>()
556+ verts.packed_accessor64 <float , 2 , torch::RestrictPtrTraits>(),
557+ faces.packed_accessor64 <int , 2 , torch::RestrictPtrTraits>()
558558 );
559559 cudaDeviceSynchronize ();
560560
0 commit comments