Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: stream compaction overdraw #35

Merged
merged 3 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/render/bindings.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ struct DrawIndirect {
base_instance: u32,
}
struct SortingGlobal {
status_counters: array<array<atomic<u32>, #{RADIX_BASE}>, #{MAX_TILE_COUNT_C}>,
digit_histogram: array<array<atomic<u32>, #{RADIX_BASE}>, #{RADIX_DIGIT_PLACES}>,
assignment_counter: atomic<u32>,
}
Expand All @@ -50,7 +49,8 @@ struct Entry {

@group(3) @binding(0) var<uniform> sorting_pass_index: u32;
@group(3) @binding(1) var<storage, read_write> sorting: SortingGlobal;
@group(3) @binding(2) var<storage, read_write> draw_indirect: DrawIndirect;
@group(3) @binding(3) var<storage, read_write> input_entries: array<Entry>;
@group(3) @binding(4) var<storage, read_write> output_entries: array<Entry>;
@group(3) @binding(5) var<storage, read> sorted_entries: array<Entry>;
@group(3) @binding(2) var<storage, read_write> status_counters: array<array<atomic<u32>, #{RADIX_BASE}>>;
@group(3) @binding(3) var<storage, read_write> draw_indirect: DrawIndirect;
@group(3) @binding(4) var<storage, read_write> input_entries: array<Entry>;
@group(3) @binding(5) var<storage, read_write> output_entries: array<Entry>;
@group(3) @binding(6) var<storage, read> sorted_entries: array<Entry>;
3 changes: 2 additions & 1 deletion src/render/gaussian.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,10 @@ fn vs_points(
var output: GaussianOutput;
let splat_index = sorted_entries[instance_index][1];

let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu;
let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu || splat_index == 0u;
if (discard_quad) {
output.color = vec4<f32>(0.0, 0.0, 0.0, 0.0);
output.position = vec4<f32>(0.0, 0.0, 0.0, 0.0);
return output;
}

Expand Down
105 changes: 72 additions & 33 deletions src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,11 @@ pub struct GpuGaussianSplattingBundle {
#[derive(Debug, Clone)]
pub struct GpuGaussianCloud {
pub gaussian_buffer: Buffer,
pub count: u32,
pub count: usize,

pub draw_indirect_buffer: Buffer,
pub sorting_global_buffer: Buffer,
pub sorting_status_counter_buffer: Buffer,
pub sorting_pass_buffers: [Buffer; 4],
pub entry_buffer_a: Buffer,
pub entry_buffer_b: Buffer,
Expand All @@ -210,16 +211,22 @@ impl RenderAsset for GaussianCloud {
usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE,
});

let count = gaussian_cloud.0.len() as u32;
let count = gaussian_cloud.0.len();

// TODO: derive sorting_buffer_size from cloud count (with possible rounding to next power of 2)
let sorting_global_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("sorting global buffer"),
size: ShaderDefines::default().sorting_buffer_size as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let sorting_status_counter_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("status counters buffer"),
size: ShaderDefines::default().sorting_status_counters_buffer_size(count) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_DST | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let draw_indirect_buffer = render_device.create_buffer(&BufferDescriptor {
label: Some("draw indirect buffer"),
size: std::mem::size_of::<wgpu::util::DrawIndirect>() as u64,
Expand All @@ -241,14 +248,14 @@ impl RenderAsset for GaussianCloud {

let entry_buffer_a = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer a"),
size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});

let entry_buffer_b = render_device.create_buffer(&BufferDescriptor {
label: Some("entry buffer b"),
size: (count as usize * std::mem::size_of::<(u32, u32)>()) as u64,
size: (count * std::mem::size_of::<(u32, u32)>()) as u64,
usage: BufferUsages::STORAGE | BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
Expand All @@ -258,6 +265,7 @@ impl RenderAsset for GaussianCloud {
count,
draw_indirect_buffer,
sorting_global_buffer,
sorting_status_counter_buffer,
sorting_pass_buffers,
entry_buffer_a,
entry_buffer_b,
Expand Down Expand Up @@ -409,9 +417,20 @@ impl FromWorld for GaussianCloudPipeline {
count: None,
};

let draw_indirect_buffer_entry = BindGroupLayoutEntry {
let sorting_status_counters_buffer_entry = BindGroupLayoutEntry {
binding: 2,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: BufferSize::new(ShaderDefines::default().sorting_status_counters_buffer_size(1) as u64),
},
count: None,
};

let draw_indirect_buffer_entry = BindGroupLayoutEntry {
binding: 3,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
Expand All @@ -434,9 +453,10 @@ impl FromWorld for GaussianCloudPipeline {
count: None,
},
sorting_buffer_entry,
sorting_status_counters_buffer_entry,
draw_indirect_buffer_entry,
BindGroupLayoutEntry {
binding: 3,
binding: 4,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
Expand All @@ -446,7 +466,7 @@ impl FromWorld for GaussianCloudPipeline {
count: None,
},
BindGroupLayoutEntry {
binding: 4,
binding: 5,
visibility: ShaderStages::COMPUTE,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: false },
Expand All @@ -462,7 +482,7 @@ impl FromWorld for GaussianCloudPipeline {
label: Some("sorted_layout"),
entries: &vec![
BindGroupLayoutEntry {
binding: 5,
binding: 6,
visibility: ShaderStages::VERTEX_FRAGMENT,
ty: BindingType::Buffer {
ty: BufferBindingType::Storage { read_only: true },
Expand All @@ -474,7 +494,7 @@ impl FromWorld for GaussianCloudPipeline {
],
});

let compute_layout = vec![
let sorting_layout = vec![
view_layout.clone(),
gaussian_uniform_layout.clone(),
gaussian_cloud_layout.clone(),
Expand All @@ -485,7 +505,7 @@ impl FromWorld for GaussianCloudPipeline {
let pipeline_cache = render_world.resource::<PipelineCache>();
let radix_sort_a = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_a".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -494,7 +514,7 @@ impl FromWorld for GaussianCloudPipeline {

let radix_sort_b = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_b".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -503,7 +523,7 @@ impl FromWorld for GaussianCloudPipeline {

let radix_sort_c = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("radix_sort_c".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: RADIX_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -513,7 +533,7 @@ impl FromWorld for GaussianCloudPipeline {

let temporal_sort_flip = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("temporal_sort_flip".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: TEMPORAL_SORT_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand All @@ -522,7 +542,7 @@ impl FromWorld for GaussianCloudPipeline {

let temporal_sort_flop = pipeline_cache.queue_compute_pipeline(ComputePipelineDescriptor {
label: Some("temporal_sort_flop".into()),
layout: compute_layout.clone(),
layout: sorting_layout.clone(),
push_constant_ranges: vec![],
shader: TEMPORAL_SORT_SHADER_HANDLE,
shader_defs: shader_defs.clone(),
Expand Down Expand Up @@ -560,12 +580,21 @@ struct ShaderDefines {
workgroup_invocations_c: u32,
workgroup_entries_a: u32,
workgroup_entries_c: u32,
max_tile_count_c: u32,
sorting_buffer_size: usize,
sorting_buffer_size: u32,

temporal_sort_window_size: u32,
}

impl ShaderDefines {
    /// Returns the number of tiles needed to cover `count` entries, where one
    /// tile holds `workgroup_entries_c` entries (ceiling division).
    ///
    /// The division is carried out in `u64` and split into quotient plus a
    /// remainder check, so neither the `usize -> u32` narrowing nor the
    /// rounding term can overflow or truncate for large `count`.
    ///
    /// # Panics
    ///
    /// Panics if `workgroup_entries_c` is zero, or if the resulting tile count
    /// does not fit in a `u32`.
    fn max_tile_count(&self, count: usize) -> u32 {
        // `usize` is at most 64 bits wide, so this widening is lossless.
        let entries = count as u64;
        let entries_per_tile = u64::from(self.workgroup_entries_c);

        // Round up without ever forming `entries + entries_per_tile - 1`.
        let full_tiles = entries / entries_per_tile;
        let partial_tile = u64::from(entries % entries_per_tile != 0);

        u32::try_from(full_tiles + partial_tile)
            .expect("tile count for the sorting pass does not fit in a u32")
    }

    /// Returns the size in bytes of a status-counter buffer for `count`
    /// entries: `radix_base` counters of type `u32` for every tile reported by
    /// [`ShaderDefines::max_tile_count`].
    fn sorting_status_counters_buffer_size(&self, count: usize) -> usize {
        let counters_per_tile = self.radix_base as usize;
        let tile_count = self.max_tile_count(count) as usize;

        counters_per_tile * tile_count * std::mem::size_of::<u32>()
    }
}

impl Default for ShaderDefines {
fn default() -> Self {
let radix_bits_per_digit = 8;
Expand All @@ -577,10 +606,8 @@ impl Default for ShaderDefines {
let workgroup_invocations_c = radix_base;
let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a;
let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c;
let max_tile_count_c = (10000000 + workgroup_entries_c - 1) / workgroup_entries_c;
let sorting_buffer_size = radix_base as usize *
(radix_digit_places as usize + max_tile_count_c as usize) *
std::mem::size_of::<u32>() + 5 * std::mem::size_of::<u32>();
let sorting_buffer_size = radix_base * radix_digit_places *
std::mem::size_of::<u32>() as u32 + 5 * std::mem::size_of::<u32>() as u32;

Self {
radix_bits_per_digit,
Expand All @@ -592,7 +619,6 @@ impl Default for ShaderDefines {
workgroup_invocations_c,
workgroup_entries_a,
workgroup_entries_c,
max_tile_count_c,
sorting_buffer_size,

temporal_sort_window_size: 16,
Expand All @@ -615,7 +641,6 @@ fn shader_defs(
ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_A".into(), defines.workgroup_invocations_a),
ShaderDefVal::UInt("WORKGROUP_INVOCATIONS_C".into(), defines.workgroup_invocations_c),
ShaderDefVal::UInt("WORKGROUP_ENTRIES_C".into(), defines.workgroup_entries_c),
ShaderDefVal::UInt("MAX_TILE_COUNT_C".into(), defines.max_tile_count_c),

ShaderDefVal::UInt("TEMPORAL_SORT_WINDOW_SIZE".into(), defines.temporal_sort_window_size),
];
Expand Down Expand Up @@ -833,8 +858,17 @@ pub fn queue_gaussian_bind_group(
}),
};

let draw_indirect_entry = BindGroupEntry {
let sorting_status_counters_entry = BindGroupEntry {
binding: 2,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.sorting_status_counter_buffer,
offset: 0,
size: BufferSize::new(cloud.sorting_status_counter_buffer.size()),
}),
};

let draw_indirect_entry = BindGroupEntry {
binding: 3,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.draw_indirect_buffer,
offset: 0,
Expand All @@ -857,9 +891,10 @@ pub fn queue_gaussian_bind_group(
}),
},
sorting_global_entry.clone(),
sorting_status_counters_entry.clone(),
draw_indirect_entry.clone(),
BindGroupEntry {
binding: 3,
binding: 4,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&cloud.entry_buffer_a
Expand All @@ -871,7 +906,7 @@ pub fn queue_gaussian_bind_group(
}),
},
BindGroupEntry {
binding: 4,
binding: 5,
resource: BindingResource::Buffer(BufferBinding {
buffer: if idx % 2 == 0 {
&cloud.entry_buffer_b
Expand Down Expand Up @@ -910,7 +945,7 @@ pub fn queue_gaussian_bind_group(
&gaussian_cloud_pipeline.sorted_layout,
&[
BindGroupEntry {
binding: 5,
binding: 6,
resource: BindingResource::Buffer(BufferBinding {
buffer: &cloud.entry_buffer_a,
offset: 0,
Expand Down Expand Up @@ -1173,6 +1208,12 @@ impl render_graph::Node for RadixSortNode {
None,
);

command_encoder.clear_buffer(
&cloud.sorting_status_counter_buffer,
0,
None,
);

command_encoder.clear_buffer(
&cloud.draw_indirect_buffer,
0,
Expand Down Expand Up @@ -1208,7 +1249,7 @@ impl render_graph::Node for RadixSortNode {
pass.set_pipeline(radix_sort_a);

let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a;
pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);
pass.dispatch_workgroups((cloud.count as u32 + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1);


let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap();
Expand All @@ -1219,12 +1260,10 @@ impl render_graph::Node for RadixSortNode {

for pass_idx in 0..radix_digit_places {
if pass_idx > 0 {
// clear SortingGlobal.status_counters
let size = (ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c) as u64 * std::mem::size_of::<u32>() as u64;
command_encoder.clear_buffer(
&cloud.sorting_global_buffer,
&cloud.sorting_status_counter_buffer,
0,
std::num::NonZeroU64::new(size).unwrap().into()
None,
);
}

Expand Down Expand Up @@ -1255,7 +1294,7 @@ impl render_graph::Node for RadixSortNode {
);

let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c;
pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1);
pass.dispatch_workgroups(1, (cloud.count as u32 + workgroup_entries_c - 1) / workgroup_entries_c, 1);
}
}
}
Expand Down
Loading