From eedd27e0f32bdf33f324238284c915dc6e1574d2 Mon Sep 17 00:00:00 2001 From: mosure Date: Sun, 5 Nov 2023 19:39:13 -0600 Subject: [PATCH] feat: radix pre-sort working (/w instability) --- src/render/gaussian.wgsl | 40 +++---- src/render/mod.rs | 226 +++++++++++++++++++++++++++++++++++++-- viewer/viewer.rs | 16 +-- 3 files changed, 244 insertions(+), 38 deletions(-) diff --git a/src/render/gaussian.wgsl b/src/render/gaussian.wgsl index 945c1ab1..ce434a7d 100644 --- a/src/render/gaussian.wgsl +++ b/src/render/gaussian.wgsl @@ -60,6 +60,7 @@ struct SortingSharedA { } var sorting_shared_a: SortingSharedA; +// TODO: resolve flickering (maybe more radix passes?) @compute @workgroup_size(#{RADIX_BASE}, #{RADIX_DIGIT_PLACES}) fn radix_sort_a( @builtin(local_invocation_id) gl_LocalInvocationID: vec3, @@ -76,10 +77,11 @@ fn radix_sort_a( continue; } var key: u32 = 0xFFFFFFFFu; // Stream compaction for frustum culling - let clip_space_pos = world_to_clip(points[entry_index].position.xyz); + let transformed_position = (uniforms.global_transform * points[entry_index].position).xyz; + let clip_space_pos = world_to_clip(transformed_position); if(in_frustum(clip_space_pos.xyz)) { - // key = bitcast(clip_space_pos.z); - key = u32(clip_space_pos.z * 0xFFFF.0) << 16u; + // key = bitcast(1.0 - clip_space_pos.z); + key = u32((1.0 - clip_space_pos.z) * 0xFFFF.0) << 16u; key |= u32((clip_space_pos.x * 0.5 + 0.5) * 0xFF.0) << 8u; key |= u32((clip_space_pos.y * 0.5 + 0.5) * 0xFF.0); } @@ -118,7 +120,7 @@ var sorting_shared_c: SortingSharedC; const NUM_BANKS: u32 = 16u; const LOG_NUM_BANKS: u32 = 4u; fn conflict_free_offset(n: u32) -> u32 { - return 0u; // n >> NUM_BANKS + n >> (2u * LOG_NUM_BANKS); + return n >> NUM_BANKS + n >> (2u * LOG_NUM_BANKS); } fn exclusive_scan(local_invocation_index: u32, value: u32) -> u32 { @@ -375,9 +377,9 @@ fn compute_cov2d(position: vec3, scale: vec3, rotation: vec4) -> #ifdef USE_AABB let W = transpose( mat3x3( - view.projection.x.xyz, - view.projection.y.xyz, - view.projection.z.xyz, + view.inverse_view.x.xyz, + view.inverse_view.y.xyz, + view.inverse_view.z.xyz, ) ); #endif @@ -385,9 +387,9 @@ fn compute_cov2d(position: vec3, scale: vec3, rotation: vec4) -> #ifdef USE_OBB let W = transpose( mat3x3( - view.projection.x.xyz, - view.projection.y.xyz, - view.projection.z.xyz, + view.inverse_view.x.xyz, + view.inverse_view.y.xyz, + view.inverse_view.z.xyz, ) ); #endif @@ -395,8 +397,8 @@ fn compute_cov2d(position: vec3, scale: vec3, rotation: vec4) -> let T = W * J; var cov = transpose(T) * transpose(Vrk) * T; - // cov[0][0] += 0.3f; - // cov[1][1] += 0.3f; + cov[0][0] += 0.3f; + cov[1][1] += 0.3f; return vec3(cov[0][0], cov[0][1], cov[1][1]); } @@ -498,14 +500,13 @@ fn vs_points( @builtin(vertex_index) vertex_index: u32, ) -> GaussianOutput { var output: GaussianOutput; - let splat_index = instance_index; - // let splat_index = sorted_entries[instance_index][1]; + let splat_index = sorted_entries[instance_index][1]; - // let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu; - // if (discard_quad) { - // output.color = vec4(0.0, 0.0, 0.0, 0.0); - // return output; - // } + let discard_quad = sorted_entries[instance_index][0] == 0xFFFFFFFFu; + if (discard_quad) { + output.color = vec4(0.0, 0.0, 0.0, 0.0); + return output; + } let point = points[splat_index]; let transformed_position = (uniforms.global_transform * point.position).xyz; @@ -513,6 +514,7 @@ fn vs_points( let projected_position = world_to_clip(transformed_position); if (!in_frustum(projected_position.xyz)) { output.color = vec4(0.0, 0.0, 0.0, 0.0); + output.position = vec4(0.0, 0.0, 0.0, 0.0); return output; } diff --git a/src/render/mod.rs b/src/render/mod.rs index d540e9c9..4febd734 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -7,7 +7,10 @@ use bevy::{ HandleUntyped, LoadState, }, - core_pipeline::core_3d::Transparent3d, + core_pipeline::core_3d::{ + Transparent3d, + CORE_3D, + }, ecs::{ system::{ lifetimeless::*, @@ -44,7 +47,10 @@ use bevy::{ TrackedRenderPass, }, render_resource::*, - renderer::RenderDevice, + renderer::{ + RenderDevice, + RenderContext, + }, Render, RenderApp, RenderSet, @@ -54,6 +60,10 @@ use bevy::{ ViewUniforms, ViewUniformOffset, }, + render_graph::{ + self, + RenderGraphApp, + }, }, }; @@ -68,6 +78,10 @@ use crate::gaussian::{ const GAUSSIAN_SHADER_HANDLE: HandleUntyped = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 68294581); const SPHERICAL_HARMONICS_SHADER_HANDLE: HandleUntyped = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 834667312); +pub mod node { + pub const RADIX_SORT: &str = "radix_sort"; +} + #[derive(Default)] pub struct RenderPipelinePlugin; @@ -92,6 +106,17 @@ impl Plugin for RenderPipelinePlugin { app.add_plugins(UniformComponentPlugin::::default()); if let Ok(render_app) = app.get_sub_app_mut(RenderApp) { + render_app + .add_render_graph_node::( + CORE_3D, + node::RADIX_SORT, + ) + .add_render_graph_edge( + CORE_3D, + node::RADIX_SORT, + bevy::core_pipeline::core_3d::graph::node::PREPASS, + ); + render_app .add_render_command::() .init_resource::() @@ -222,12 +247,13 @@ fn queue_gaussians( &Handle, &GaussianCloudSettings, )>, - mut views: Query<(&ExtractedView, &mut RenderPhase)>, + mut views: Query<( + &ExtractedView, + &mut RenderPhase, + )>, ) { let draw_custom = transparent_3d_draw_functions.read().id::(); - // TODO: add compute pipelines to pipeline cache & compute phase - for (_view, mut transparent_phase) in &mut views { for (entity, cloud, settings) in &gaussian_splatting_bundles { if let Some(_cloud) = gaussian_clouds.get(cloud) { @@ -489,7 +515,7 @@ struct ShaderDefines { entries_per_invocation_c: u32, workgroup_invocations_a: u32, workgroup_invocations_c: u32, - _workgroup_entries_a: u32, + workgroup_entries_a: u32, workgroup_entries_c: u32, max_tile_count_c: u32, sorting_buffer_size: usize, @@ -506,7 +532,7 @@ impl Default for ShaderDefines { let entries_per_invocation_c = 4; let workgroup_invocations_a = radix_base * radix_digit_places; let workgroup_invocations_c = radix_base; - let _workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a; + let workgroup_entries_a = workgroup_invocations_a * entries_per_invocation_a; let workgroup_entries_c = workgroup_invocations_c * entries_per_invocation_c; let max_tile_count_c = (10000000 + workgroup_entries_c - 1) / workgroup_entries_c; let sorting_buffer_size = ( @@ -523,7 +549,7 @@ impl Default for ShaderDefines { entries_per_invocation_c, workgroup_invocations_a, workgroup_invocations_c, - _workgroup_entries_a, + workgroup_entries_a, workgroup_entries_c, max_tile_count_c, sorting_buffer_size, @@ -971,8 +997,6 @@ impl RenderCommand

for SetGaussianUniformBindGr } } -// TODO: add compute phase - pub struct DrawGaussianInstanced; impl RenderCommand

for DrawGaussianInstanced { type Param = SRes>; @@ -1004,11 +1028,191 @@ impl RenderCommand

for DrawGaussianInstanced { pass.set_bind_group(2, &bind_groups.cloud_bind_group, &[]); pass.set_bind_group(3, &bind_groups.sorted_bind_group, &[]); - pass.draw(0..4, 0..gpu_gaussian_cloud.count as u32); pass.draw_indirect(&gpu_gaussian_cloud.draw_indirect_buffer, 0); RenderCommandResult::Success } +} + + + +struct RadixSortNode { + gaussian_clouds: QueryState<( + &'static Handle, + &'static GaussianCloudBindGroup + )>, + initialized: bool, + pipeline_idx: Option, + view_bind_group: QueryState<( + &'static GaussianViewBindGroup, + &'static ViewUniformOffset, + )>, +} + +impl FromWorld for RadixSortNode { + fn from_world(world: &mut World) -> Self { + Self { + gaussian_clouds: world.query(), + initialized: false, + pipeline_idx: None, + view_bind_group: world.query(), + } + } +} + +impl render_graph::Node for RadixSortNode { + fn update(&mut self, world: &mut World) { + let pipeline = world.resource::(); + let pipeline_cache = world.resource::(); + + if !self.initialized { + let mut pipelines_loaded = true; + for sort_pipeline in pipeline.radix_sort_pipelines.iter() { + if let CachedPipelineState::Ok(_) = + pipeline_cache.get_compute_pipeline_state(*sort_pipeline) + { + continue; + } + + pipelines_loaded = false; + } + + self.initialized = pipelines_loaded; + + if !self.initialized { + return; + } + } + + if self.pipeline_idx.is_none() { + self.pipeline_idx = Some(0); + } else { + self.pipeline_idx = Some((self.pipeline_idx.unwrap() + 1) % pipeline.radix_sort_pipelines.len() as u32); + } + + self.gaussian_clouds.update_archetypes(world); + self.view_bind_group.update_archetypes(world); + } + + fn run( + &self, + _graph: &mut render_graph::RenderGraphContext, + render_context: &mut RenderContext, + world: &World, + ) -> Result<(), render_graph::NodeRunError> { + if !self.initialized || self.pipeline_idx.is_none() { + return Ok(()); + } + + let _idx = self.pipeline_idx.unwrap() as usize; // TODO: temporal sort + + let pipeline_cache = world.resource::(); + let pipeline = world.resource::(); + let gaussian_uniforms = world.resource::(); + + let command_encoder = render_context.command_encoder(); + + for ( + view_bind_group, + view_uniform_offset, + ) in self.view_bind_group.iter_manual(world) { + for ( + cloud_handle, + cloud_bind_group + ) in self.gaussian_clouds.iter_manual(world) { + let cloud = world.get_resource::>().unwrap().get(cloud_handle).unwrap(); + + let radix_digit_places = ShaderDefines::default().radix_digit_places; + + command_encoder.clear_buffer( + &cloud.sorting_global_buffer, + 0, + None, + ); + + { + let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + + // TODO: view/global + pass.set_bind_group( + 0, + &view_bind_group.value, + &[view_uniform_offset.offset], + ); + pass.set_bind_group( + 1, + gaussian_uniforms.base_bind_group.as_ref().unwrap(), + &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex + ); + pass.set_bind_group( + 2, + &cloud_bind_group.cloud_bind_group, + &[] + ); + pass.set_bind_group( + 3, + &cloud_bind_group.radix_sort_bind_groups[1], + &[], + ); + + let radix_sort_a = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[0]).unwrap(); + pass.set_pipeline(radix_sort_a); + + let workgroup_entries_a = ShaderDefines::default().workgroup_entries_a; + pass.dispatch_workgroups((cloud.count + workgroup_entries_a - 1) / workgroup_entries_a, 1, 1); + + + let radix_sort_b = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[1]).unwrap(); + pass.set_pipeline(radix_sort_b); + + pass.dispatch_workgroups(1, radix_digit_places, 1); + } + + for pass_idx in 0..radix_digit_places { + if pass_idx > 0 { + let size = ShaderDefines::default().radix_base * ShaderDefines::default().max_tile_count_c * std::mem::size_of::() as u32; + command_encoder.clear_buffer( + &cloud.sorting_global_buffer, + 0, + std::num::NonZeroU64::new(size as u64).unwrap().into() + ); + } + + let mut pass = command_encoder.begin_compute_pass(&ComputePassDescriptor::default()); + + let radix_sort_c = pipeline_cache.get_compute_pipeline(pipeline.radix_sort_pipelines[2]).unwrap(); + pass.set_pipeline(&radix_sort_c); + + pass.set_bind_group( + 0, + &view_bind_group.value, + &[view_uniform_offset.offset], + ); + pass.set_bind_group( + 1, + gaussian_uniforms.base_bind_group.as_ref().unwrap(), + &[0], // TODO: fix transforms - dynamic offset using DynamicUniformIndex + ); + pass.set_bind_group( + 2, + &cloud_bind_group.cloud_bind_group, + &[] + ); + pass.set_bind_group( + 3, + &cloud_bind_group.radix_sort_bind_groups[pass_idx as usize], + &[], + ); + + let workgroup_entries_c = ShaderDefines::default().workgroup_entries_c; + pass.dispatch_workgroups(1, (cloud.count + workgroup_entries_c - 1) / workgroup_entries_c, 1); + } + } + } + + + Ok(()) + } } diff --git a/viewer/viewer.rs b/viewer/viewer.rs index f73883ad..4e7b4fb9 100644 --- a/viewer/viewer.rs +++ b/viewer/viewer.rs @@ -51,8 +51,8 @@ fn setup_gaussian_cloud( mut commands: Commands, asset_server: Res, mut gaussian_assets: ResMut>, - mut meshes: ResMut>, - mut materials: ResMut>, + // mut meshes: ResMut>, + // mut materials: ResMut>, ) { let cloud: Handle; let settings = GaussianCloudSettings { @@ -77,12 +77,12 @@ fn setup_gaussian_cloud( Name::new("gaussian_cloud"), )); - commands.spawn(PbrBundle { - mesh: meshes.add(Mesh::from(shape::Cube { size: 1.0 })), - material: materials.add(Color::rgb(0.8, 0.3, 0.6).into()), - transform: Transform::from_xyz(0.0, 0.0, 0.0), - ..default() - }); + // commands.spawn(PbrBundle { + // mesh: meshes.add(Mesh::from(shape::Cube { size: 1.0 })), + // material: materials.add(Color::rgb(0.8, 0.3, 0.6).into()), + // transform: Transform::from_xyz(0.0, 0.0, 0.0), + // ..default() + // }); commands.spawn(( Camera3dBundle {