diff --git a/src/gaussian/settings.rs b/src/gaussian/settings.rs index b9a941aa..10f6ca5c 100644 --- a/src/gaussian/settings.rs +++ b/src/gaussian/settings.rs @@ -25,7 +25,7 @@ pub enum GaussianCloudDrawMode { pub struct GaussianCloudSettings { pub aabb: bool, pub global_scale: f32, - pub global_transform: GlobalTransform, + pub transform: Transform, pub visualize_bounding_box: bool, pub visualize_depth: bool, pub sort_mode: SortMode, @@ -37,7 +37,7 @@ impl Default for GaussianCloudSettings { Self { aabb: false, global_scale: 1.0, - global_transform: Transform::IDENTITY.into(), + transform: Transform::IDENTITY, visualize_bounding_box: false, visualize_depth: false, sort_mode: SortMode::default(), diff --git a/src/render/bindings.wgsl b/src/render/bindings.wgsl index 701fed47..f605d7cf 100644 --- a/src/render/bindings.wgsl +++ b/src/render/bindings.wgsl @@ -8,7 +8,7 @@ @group(0) @binding(1) var globals: Globals; struct GaussianUniforms { - global_transform: mat4x4, + transform: mat4x4, global_scale: f32, count: u32, count_root_ceil: u32, diff --git a/src/render/gaussian.wgsl b/src/render/gaussian.wgsl index 5db51cd8..77709f37 100644 --- a/src/render/gaussian.wgsl +++ b/src/render/gaussian.wgsl @@ -144,6 +144,12 @@ fn compute_cov3d(scale: vec3, rotation: vec4) -> array { let y = rotation.z; let z = rotation.w; + let T = mat3x3( + gaussian_uniforms.transform[0].xyz, + gaussian_uniforms.transform[1].xyz, + gaussian_uniforms.transform[2].xyz, + ); + let R = mat3x3( 1.0 - 2.0 * (y * y + z * z), 2.0 * (x * y - r * z), @@ -160,14 +166,15 @@ fn compute_cov3d(scale: vec3, rotation: vec4) -> array { let M = S * R; let Sigma = transpose(M) * M; + let TS = T * Sigma * transpose(T); return array( - Sigma[0][0], - Sigma[0][1], - Sigma[0][2], - Sigma[1][1], - Sigma[1][2], - Sigma[2][2], + TS[0][0], + TS[0][1], + TS[0][2], + TS[1][1], + TS[1][2], + TS[2][2], ); } @@ -320,7 +327,7 @@ fn vs_points( let position = vec4(get_position(splat_index), 1.0); - let transformed_position = (gaussian_uniforms.global_transform * position).xyz; + let transformed_position = (gaussian_uniforms.transform * position).xyz; let projected_position = world_to_clip(transformed_position); discard_quad |= !in_frustum(projected_position.xyz); @@ -353,8 +360,8 @@ fn vs_points( let first_position = vec4(get_position(get_entry(1u).value), 1.0); let last_position = vec4(get_position(get_entry(gaussian_uniforms.count - 1u).value), 1.0); - let min_position = (gaussian_uniforms.global_transform * first_position).xyz; - let max_position = (gaussian_uniforms.global_transform * last_position).xyz; + let min_position = (gaussian_uniforms.transform * first_position).xyz; + let max_position = (gaussian_uniforms.transform * last_position).xyz; let camera_position = view.world_position; diff --git a/src/render/mod.rs b/src/render/mod.rs index e59d1156..ab18358b 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -702,7 +702,7 @@ pub fn extract_gaussians( let cloud = gaussian_cloud_res.get(cloud_handle).unwrap(); let settings_uniform = GaussianCloudUniform { - transform: settings.global_transform.compute_matrix(), + transform: settings.transform.compute_matrix(), global_scale: settings.global_scale, count: cloud.count as u32, count_root_ceil: (cloud.count as f32).sqrt().ceil() as u32, diff --git a/src/sort/radix.wgsl b/src/sort/radix.wgsl index c010ca93..f06b7dfa 100644 --- a/src/sort/radix.wgsl +++ b/src/sort/radix.wgsl @@ -67,7 +67,7 @@ fn radix_sort_a( } var key: u32 = 0xFFFFFFFFu; // Stream compaction for frustum culling let position = vec4(get_position(entry_index), 1.0); - let transformed_position = (gaussian_uniforms.global_transform * position).xyz; + let transformed_position = (gaussian_uniforms.transform * position).xyz; let clip_space_pos = world_to_clip(transformed_position); if(in_frustum(clip_space_pos.xyz)) { // key = bitcast(1.0 - clip_space_pos.z); diff --git a/src/sort/rayon.rs b/src/sort/rayon.rs index a2890e68..a1f98190 100644 --- a/src/sort/rayon.rs +++ b/src/sort/rayon.rs @@ -1,6 +1,7 @@ use bevy::{ prelude::*, asset::LoadState, + math::Vec3A, utils::Instant, }; @@ -36,10 +37,10 @@ pub fn rayon_sort( &GaussianCloudSettings, )>, cameras: Query<( - &GlobalTransform, + &Transform, &Camera3d, )>, - mut last_camera_position: Local, + mut last_camera_position: Local, mut last_sort_time: Local>, mut period: Local, mut sort_done: Local, @@ -63,7 +64,7 @@ pub fn rayon_sort( camera_transform, _camera, ) in cameras.iter() { - let camera_position = camera_transform.compute_transform().translation; + let camera_position = camera_transform.compute_affine().translation; let camera_movement = *last_camera_position != camera_position; if camera_movement { @@ -104,7 +105,9 @@ pub fn rayon_sort( .zip(sorted_entries.sorted.par_iter_mut()) .enumerate() .for_each(|(idx, (position, sort_entry))| { - let position = Vec3::from_slice(position.as_ref()); + let position = Vec3A::from_slice(position.as_ref()); + let position = settings.transform.compute_affine().transform_point3a(position); + let delta = camera_position - position; sort_entry.key = bytemuck::cast(delta.length_squared()); diff --git a/src/sort/std.rs b/src/sort/std.rs index 14816e83..2b366536 100644 --- a/src/sort/std.rs +++ b/src/sort/std.rs @@ -1,6 +1,7 @@ use bevy::{ prelude::*, asset::LoadState, + math::Vec3A, utils::Instant, }; @@ -35,10 +36,10 @@ pub fn std_sort( &GaussianCloudSettings, )>, cameras: Query<( - &GlobalTransform, + &Transform, &Camera3d, )>, - mut last_camera_position: Local, + mut last_camera_position: Local, mut last_sort_time: Local>, mut period: Local, mut camera_debounce: Local, @@ -61,7 +62,7 @@ pub fn std_sort( camera_transform, _camera, ) in cameras.iter() { - let camera_position = camera_transform.compute_transform().translation; + let camera_position = camera_transform.compute_affine().translation; let camera_movement = *last_camera_position != camera_position; if camera_movement { @@ -107,7 +108,9 @@ pub fn std_sort( .zip(sorted_entries.sorted.iter_mut()) .enumerate() .for_each(|(idx, (position, sort_entry))| { - let position = Vec3::from_slice(position.as_ref()); + let position = Vec3A::from_slice(position.as_ref()); + let position = settings.transform.compute_affine().transform_point3a(position); + let delta = camera_position - position; sort_entry.key = bytemuck::cast(delta.length_squared());