Windows fix for marching cubes #1398
Summary: See #1398.

Reviewed By: davidsonic

Differential Revision: D42139493

fbshipit-source-id: 972fc33b9c3017554ce704f2f10190eba406b7c8
bottler authored and facebook-github-bot committed Dec 20, 2022
1 parent 3145dd4 commit 3388d3f
Showing 1 changed file with 22 additions and 26 deletions.
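The change is mechanical: every remaining use of the torch:: wrapper namespace in this .cu file is replaced by the equivalent at:: (ATen) namespace, which is the Windows fix referenced in #1398. The snippet below is an illustrative before/after sketch of that pattern, not code from the commit; the function name make_empty_faces and the fixed {0, 3} shape are hypothetical.

// Illustrative sketch of the namespace change applied throughout this commit
// (not code from the repository). Assumes a CUDA build of LibTorch/ATen.
#include <ATen/ATen.h>

at::Tensor make_empty_faces(int device_id) {
  // Previously such calls used the torch:: wrappers, e.g.
  //   torch::zeros({0, 3}, torch::TensorOptions()
  //                            .dtype(torch::kInt64)
  //                            .device(torch::kCUDA, device_id));
  // The commit switches them to the equivalent at:: forms:
  auto opt_long =
      at::TensorOptions().dtype(at::kLong).device(at::kCUDA, device_id);
  return at::zeros({0, 3}, opt_long);
}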
48 changes: 22 additions & 26 deletions pytorch3d/csrc/marching_cubes/marching_cubes.cu
@@ -13,7 +13,6 @@
 #include <thrust/scan.h>
 #include <cstdio>
 #include "marching_cubes/tables.h"
-#include "utils/pytorch3d_cutils.h"

 /*
 Parallelized marching cubes for pytorch extension
@@ -267,13 +266,12 @@ __global__ void CompactVoxelsKernel(
 // isolevel: threshold to determine isosurface intersection
 //
 __global__ void GenerateFacesKernel(
-    torch::PackedTensorAccessor32<float, 2, torch::RestrictPtrTraits> verts,
-    torch::PackedTensorAccessor<int64_t, 2, torch::RestrictPtrTraits> faces,
-    torch::PackedTensorAccessor<int64_t, 1, torch::RestrictPtrTraits> ids,
-    torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>
+    at::PackedTensorAccessor32<float, 2, at::RestrictPtrTraits> verts,
+    at::PackedTensorAccessor<int64_t, 2, at::RestrictPtrTraits> faces,
+    at::PackedTensorAccessor<int64_t, 1, at::RestrictPtrTraits> ids,
+    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits>
         compactedVoxelArray,
-    torch::PackedTensorAccessor32<int, 1, torch::RestrictPtrTraits>
-        numVertsScanned,
+    at::PackedTensorAccessor32<int, 1, at::RestrictPtrTraits> numVertsScanned,
     const uint activeVoxels,
     const at::PackedTensorAccessor32<float, 3, at::RestrictPtrTraits> vol,
     const at::PackedTensorAccessor32<int, 2, at::RestrictPtrTraits> faceTable,
@@ -436,15 +434,15 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> MarchingCubesCuda(
   cudaStream_t stream = at::cuda::getCurrentCUDAStream();

   // transfer _FACE_TABLE data to device
-  torch::Tensor face_table_tensor = torch::zeros(
-      {256, 16}, torch::TensorOptions().dtype(at::kInt).device(at::kCPU));
+  at::Tensor face_table_tensor = at::zeros(
+      {256, 16}, at::TensorOptions().dtype(at::kInt).device(at::kCPU));
   auto face_table_a = face_table_tensor.accessor<int, 2>();
   for (int i = 0; i < 256; i++) {
     for (int j = 0; j < 16; j++) {
       face_table_a[i][j] = _FACE_TABLE[i][j];
     }
   }
-  torch::Tensor faceTable = face_table_tensor.to(vol.device());
+  at::Tensor faceTable = face_table_tensor.to(vol.device());

   // get numVoxels
   int threads = 128;
@@ -458,10 +456,10 @@
   }

   auto d_voxelVerts =
-      torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
           .to(vol.device());
   auto d_voxelOccupied =
-      torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
           .to(vol.device());

   // Execute "ClassifyVoxelKernel" kernel to precompute
@@ -480,7 +478,7 @@
   // If the number of active voxels is 0, return zero tensor for verts and
   // faces.
   auto d_voxelOccupiedScan =
-      torch::zeros({numVoxels}, torch::TensorOptions().dtype(at::kInt))
+      at::zeros({numVoxels}, at::TensorOptions().dtype(at::kInt))
           .to(vol.device());
   ThrustScanWrapper(
       d_voxelOccupiedScan.data_ptr<int>(),
@@ -493,23 +491,21 @@
   int activeVoxels = lastElement + lastScan;

   const int device_id = vol.device().index();
-  auto opt =
-      torch::TensorOptions().dtype(torch::kInt).device(torch::kCUDA, device_id);
-  auto opt_long = torch::TensorOptions()
-                      .dtype(torch::kInt64)
-                      .device(torch::kCUDA, device_id);
+  auto opt = at::TensorOptions().dtype(at::kInt).device(at::kCUDA, device_id);
+  auto opt_long =
+      at::TensorOptions().dtype(at::kLong).device(at::kCUDA, device_id);

   if (activeVoxels == 0) {
     int ntris = 0;
-    torch::Tensor verts = torch::zeros({ntris * 3, 3}, vol.options());
-    torch::Tensor faces = torch::zeros({ntris, 3}, opt_long);
-    torch::Tensor ids = torch::zeros({ntris}, opt_long);
+    at::Tensor verts = at::zeros({ntris * 3, 3}, vol.options());
+    at::Tensor faces = at::zeros({ntris, 3}, opt_long);
+    at::Tensor ids = at::zeros({ntris}, opt_long);
     return std::make_tuple(verts, faces, ids);
   }

   // Execute "CompactVoxelsKernel" kernel to compress voxels for accleration.
   // This allows us to run triangle generation on only the occupied voxels.
-  auto d_compVoxelArray = torch::zeros({activeVoxels}, opt);
+  auto d_compVoxelArray = at::zeros({activeVoxels}, opt);
   CompactVoxelsKernel<<<grid, threads, 0, stream>>>(
       d_compVoxelArray.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
       d_voxelOccupied.packed_accessor32<int, 1, at::RestrictPtrTraits>(),
@@ -519,7 +515,7 @@
   cudaDeviceSynchronize();

   // Scan d_voxelVerts array to generate offsets of vertices for each voxel
-  auto d_voxelVertsScan = torch::zeros({numVoxels}, opt);
+  auto d_voxelVertsScan = at::zeros({numVoxels}, opt);
   ThrustScanWrapper(
       d_voxelVertsScan.data_ptr<int>(),
       d_voxelVerts.data_ptr<int>(),
@@ -533,10 +529,10 @@
   // Execute "GenerateFacesKernel" kernel
   // This runs only on the occupied voxels.
   // It looks up the field values and generates the triangle data.
-  torch::Tensor verts = torch::zeros({totalVerts, 3}, vol.options());
-  torch::Tensor faces = torch::zeros({totalVerts / 3, 3}, opt_long);
+  at::Tensor verts = at::zeros({totalVerts, 3}, vol.options());
+  at::Tensor faces = at::zeros({totalVerts / 3, 3}, opt_long);

-  torch::Tensor ids = torch::zeros({totalVerts}, opt_long);
+  at::Tensor ids = at::zeros({totalVerts}, opt_long);

   dim3 grid2((activeVoxels + threads - 1) / threads, 1, 1);
   if (grid2.x > 65535) {
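For orientation, the comments in the hunks above describe the host-side flow of MarchingCubesCuda: classify voxels, scan occupancy, compact to occupied voxels, scan per-voxel vertex counts, then generate faces. The following is a simplified sketch of that flow using the at:: calls adopted by this commit; kernel launches, accessor plumbing, and the scans are elided as comments, so it is not the actual contents of marching_cubes.cu, and the function name marching_cubes_sketch is hypothetical.

// Simplified sketch of the host flow visible in this diff (illustrative only).
#include <ATen/ATen.h>
#include <tuple>

std::tuple<at::Tensor, at::Tensor, at::Tensor> marching_cubes_sketch(
    const at::Tensor& vol) {
  const int device_id = vol.device().index();
  auto opt_long =
      at::TensorOptions().dtype(at::kLong).device(at::kCUDA, device_id);

  // 1. ClassifyVoxelKernel: per-voxel vertex counts and occupancy flags.
  // 2. Exclusive scan of occupancy (ThrustScanWrapper) -> activeVoxels.
  // 3. CompactVoxelsKernel: indices of the occupied voxels only.
  // 4. Exclusive scan of vertex counts -> per-voxel offsets and totalVerts.
  // 5. GenerateFacesKernel: fill verts, faces, ids for the occupied voxels.

  int64_t totalVerts = 0; // produced by step 4 in the real code
  at::Tensor verts = at::zeros({totalVerts, 3}, vol.options());
  at::Tensor faces = at::zeros({totalVerts / 3, 3}, opt_long);
  at::Tensor ids = at::zeros({totalVerts}, opt_long);
  return std::make_tuple(verts, faces, ids);
}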
