From 3388d3f0aa6bc44fe704fca78d11743a0fcac38c Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Tue, 20 Dec 2022 04:07:04 -0800
Subject: [PATCH] Windows fix for marching cubes #1398

Summary: See https://github.com/facebookresearch/pytorch3d/issues/1398 .

Reviewed By: davidsonic

Differential Revision: D42139493

fbshipit-source-id: 972fc33b9c3017554ce704f2f10190eba406b7c8
---
 .../csrc/marching_cubes/marching_cubes.cu | 48 +++++++++----------
 1 file changed, 22 insertions(+), 26 deletions(-)

diff --git a/pytorch3d/csrc/marching_cubes/marching_cubes.cu b/pytorch3d/csrc/marching_cubes/marching_cubes.cu
index e596c43ec..527bced5d 100644
--- a/pytorch3d/csrc/marching_cubes/marching_cubes.cu
+++ b/pytorch3d/csrc/marching_cubes/marching_cubes.cu
@@ -13,7 +13,6 @@
 #include
 #include
 #include "marching_cubes/tables.h"
-#include "utils/pytorch3d_cutils.h"

 /*
 Parallelized marching cubes for pytorch extension
@@ -267,13 +266,12 @@ __global__ void CompactVoxelsKernel(
 // isolevel: threshold to determine isosurface intersection
 //
 __global__ void GenerateFacesKernel(
-    torch::PackedTensorAccessor32 verts,
-    torch::PackedTensorAccessor faces,
-    torch::PackedTensorAccessor ids,
-    torch::PackedTensorAccessor32
+    at::PackedTensorAccessor32 verts,
+    at::PackedTensorAccessor faces,
+    at::PackedTensorAccessor ids,
+    at::PackedTensorAccessor32
         compactedVoxelArray,
-    torch::PackedTensorAccessor32
-        numVertsScanned,
+    at::PackedTensorAccessor32 numVertsScanned,
     const uint activeVoxels,
     const at::PackedTensorAccessor32 vol,
     const at::PackedTensorAccessor32 faceTable,
@@ -436,15 +434,15 @@ std::tuple MarchingCubesCuda(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   // transfer _FACE_TABLE data to device
-  torch::Tensor face_table_tensor = torch::zeros(
-      {256, 16}, torch::TensorOptions().dtype(at::kInt).device(at::kCPU));
+  at::Tensor face_table_tensor = at::zeros(
+      {256, 16}, at::TensorOptions().dtype(at::kInt).device(at::kCPU));
   auto face_table_a = face_table_tensor.accessor();
   for (int i = 0; i < 256; i++) {
     for (int j = 0; j < 16; j++) {
       face_table_a[i][j] = _FACE_TABLE[i][j];
     }
   }
-  torch::Tensor faceTable = face_table_tensor.to(vol.device());
+  at::Tensor faceTable = face_table_tensor.to(vol.device());

   // get numVoxels
   int threads = 128;
@@ -458,10 +456,10 @@ std::tuple MarchingCubesCuda(
   }

   auto d_voxelVerts =
-      torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
          .to(vol.device());
   auto d_voxelOccupied =
-      torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
          .to(vol.device());

   // Execute "ClassifyVoxelKernel" kernel to precompute
@@ -480,7 +478,7 @@ std::tuple MarchingCubesCuda(
   // If the number of active voxels is 0, return zero tensor for verts and
   // faces.
   auto d_voxelOccupiedScan =
-      torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
          .to(vol.device());
   ThrustScanWrapper(
       d_voxelOccupiedScan.data_ptr(),
@@ -493,23 +491,21 @@ std::tuple MarchingCubesCuda(
   int activeVoxels = lastElement + lastScan;

   const int device_id = vol.device().index();
-  auto opt =
-      torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA, device_id);
-  auto opt_long = torch::TensorOptions()
-                      .dtype(torch::kInt64)
-                      .device(torch::kCUDA, device_id);
+  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
+  auto opt_long =
+      at::TensorOptions().dtype(at::kLong).device(at::kCUDA, device_id);

   if (activeVoxels == 0) {
     int ntris = 0;
-    torch::Tensor verts = torch::zeros({ntris * 3, 3}, vol.options());
-    torch::Tensor faces = torch::zeros({ntris, 3}, opt_long);
-    torch::Tensor ids = torch::zeros({ntris}, opt_long);
+    at::Tensor verts = at::zeros({ntris * 3, 3}, vol.options());
+    at::Tensor faces = at::zeros({ntris, 3}, opt_long);
+    at::Tensor ids = at::zeros({ntris}, opt_long);
     return std::make_tuple(verts, faces, ids);
   }

   // Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
   // This allows us to run triangle generation on only the occupied voxels.
-  auto d_compVoxelArray = torch::zeros({activeVoxels}, opt);
+  auto d_compVoxelArray = at::zeros({activeVoxels}, opt);
   CompactVoxelsKernel<<>>(
       d_compVoxelArray.packed_accessor32(),
       d_voxelOccupied.packed_accessor32(),
@@ -519,7 +515,7 @@ std::tuple MarchingCubesCuda(
   cudaDeviceSynchronize();

   // Scan d_voxelVerts array to generate offsets of vertices for each voxel
-  auto d_voxelVertsScan = torch::zeros({numVoxels}, opt);
+  auto d_voxelVertsScan = at::zeros({numVoxels}, opt);
   ThrustScanWrapper(
       d_voxelVertsScan.data_ptr(),
       d_voxelVerts.data_ptr(),
@@ -533,10 +529,10 @@ std::tuple MarchingCubesCuda(
   // Execute "GenerateFacesKernel" kernel
   // This runs only on the occupied voxels.
   // It looks up the field values and generates the triangle data.
-  torch::Tensor verts = torch::zeros({totalVerts, 3}, vol.options());
-  torch::Tensor faces = torch::zeros({totalVerts / 3, 3}, opt_long);
+  at::Tensor verts = at::zeros({totalVerts, 3}, vol.options());
+  at::Tensor faces = at::zeros({totalVerts / 3, 3}, opt_long);

-  torch::Tensor ids = torch::zeros({totalVerts}, opt_long);
+  at::Tensor ids = at::zeros({totalVerts}, opt_long);

   dim3 grid2((activeVoxels + threads - 1) / threads, 1, 1);
   if (grid2.x > 65535) {