Skip to content

Commit

Permalink
marching_cubes type fix
Browse files Browse the repository at this point in the history
Summary: fixes #1679

Reviewed By: MichaelRamamonjisoa

Differential Revision: D50949933

fbshipit-source-id: 5c467de8bf84dd2a3d61748b3846678582d24ea3
  • Loading branch information
bottler authored and facebook-github-bot committed Nov 14, 2023
1 parent 2f11ddc commit f613682
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
12 changes: 6 additions & 6 deletions pytorch3d/csrc/marching_cubes/marching_cubes.cu
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ __global__ void CompactVoxelsKernel(
compactedVoxelArray,
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
voxelOccupied,
const at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
const at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits>
voxelOccupiedScan,
uint numVoxels) {
uint id = blockIdx.x * blockDim.x + threadIdx.x;
Expand Down Expand Up @@ -255,7 +255,7 @@ __global__ void GenerateFacesKernel(
at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
compactedVoxelArray,
at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> numVertsScanned,
at::PackedTensorAccessor32<long, 1, at::RestrictPtrTraits> numVertsScanned,
const uint activeVoxels,
const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
Expand Down Expand Up @@ -471,7 +471,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelOccupiedScan_ = d_voxelOccupiedScan.index({Slice(1, None)});

// number of active voxels
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<int>();
int activeVoxels = d_voxelOccupiedScan[numVoxels].cpu().item<long>();

const int device_id = vol.device().index();
auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
Expand All @@ -492,7 +492,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelOccupiedScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
numVoxels);
AT_CUDA_CHECK(cudaGetLastError());
cudaDeviceSynchronize();
Expand All @@ -502,7 +502,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
auto d_voxelVertsScan_ = d_voxelVertsScan.index({Slice(1, None)});

// total number of vertices
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<int>();
int totalVerts = d_voxelVertsScan[numVoxels].cpu().item<long>();

// Execute "GenerateFacesKernel" kernel
// This runs only on the occupied voxels.
Expand All @@ -522,7 +522,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
faces.packed_accessor<int64_t, 2, at::RestrictPtrTraits>(),
ids.packed_accessor<int64_t, 1, at::RestrictPtrTraits>(),
d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
d_voxelVertsScan_.packed_accessor32<long, 1, at::RestrictPtrTraits>(),
activeVoxels,
vol.packed_accessor32<float, 3, at::RestrictPtrTraits>(),
faceTable.packed_accessor32<int, 2, at::RestrictPtrTraits>(),
Expand Down
7 changes: 5 additions & 2 deletions tests/test_marching_cubes.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,8 +939,11 @@ def test_ball_example(self):
u = u[None].float()
verts, faces = marching_cubes_naive(u, 0, return_local_coords=False)
verts2, faces2 = marching_cubes(u, 0, return_local_coords=False)
self.assertClose(verts[0], verts2[0])
self.assertClose(faces[0], faces2[0])
self.assertClose(verts2[0], verts[0])
self.assertClose(faces2[0], faces[0])
verts3, faces3 = marching_cubes(u.cuda(), 0, return_local_coords=False)
self.assertEqual(len(verts3), len(verts))
self.assertEqual(len(faces3), len(faces))

@staticmethod
def marching_cubes_with_init(algo_type: str, batch_size: int, V: int, device: str):
Expand Down

0 comments on commit f613682

Please sign in to comment.