diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 00000000..43fc9144 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,8 @@ +* text=auto eol=lf +*.{cmd,[cC][mM][dD]} text eol=crlf +*.{bat,[bB][aA][tT]} text eol=crlf +*.sh text eol=lf +*.conf text eol=lf + +*.ply binary +*.splat binary diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 0025690b..7deb85cd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -19,7 +19,7 @@ jobs: os: [windows-latest, macos-latest] runs-on: ${{ matrix.os }} - timeout-minutes: 60 + timeout-minutes: 120 steps: - uses: actions/checkout@v3 @@ -36,5 +36,8 @@ jobs: - name: build run: cargo build - - name: run tests + - name: lint + run: cargo clippy + + - name: test run: cargo test diff --git a/.gitignore b/.gitignore index 4bcdf229..7012bb23 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ Cargo.lock *.ply + +.DS_Store diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 00000000..841cab71 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,73 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'bevy_gaussian_splatting'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=bevy_gaussian_splatting" + ], + "filter": { + "name": "bevy_gaussian_splatting", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}", + "env": { + "CARGO_MANIFEST_DIR": "${workspaceFolder}", + } + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'bevy_gaussian_splatting'", + "cargo": { + "args": [ + "build", + "--bin=bevy_gaussian_splatting", + "--package=bevy_gaussian_splatting" + ], + "filter": { + "name": "bevy_gaussian_splatting", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}", + "env": { + "CARGO_MANIFEST_DIR": "${workspaceFolder}", + } + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in executable 'bevy_gaussian_splatting'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bin=bevy_gaussian_splatting", + "--package=bevy_gaussian_splatting" + ], + "filter": { + "name": "bevy_gaussian_splatting", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}", + "env": { + "CARGO_MANIFEST_DIR": "${workspaceFolder}", + } + } + ] +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index e7fbe96d..16c501b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,6 +15,7 @@ exclude = [".devcontainer", ".github", "docs", "dist", "build", "assets", "credi [dependencies] bevy = "0.11.2" bevy_panorbit_camera = "0.8.0" +bytemuck = "1.14.0" ply-rs = "0.1.3" diff --git a/README.md b/README.md index 8b263dac..d9f7905b 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# bevy_gaussian_splatting +# bevy_gaussian_splatting 🌌 [![test](https://github.com/mosure/bevy_gaussian_splatting/workflows/test/badge.svg)](https://github.com/Mosure/bevy_gaussian_splatting/actions?query=workflow%3Atest) [![GitHub License](https://img.shields.io/github/license/mosure/bevy_gaussian_splatting)](https://raw.githubusercontent.com/mosure/bevy_gaussian_splatting/main/LICENSE) @@ -15,8 +15,8 @@ bevy gaussian splatting render pipeline plugin ## capabilities - [ ] bevy gaussian cloud render pipeline +- [ ] 4D gaussian clouds via morph targets - [ ] bevy 3D camera to gaussian cloud pipeline -- [ ] 4D gaussian clouds ## usage @@ -37,7 +37,7 @@ fn setup_gaussian_cloud( asset_server: Res, ) { commands.spawn(GaussianSplattingBundle { - verticies: asset_server.load("scenes/test.ply"), + cloud: asset_server.load("scenes/icecream.ply"), ..Default::default() }); @@ -56,10 +56,17 @@ fn setup_gaussian_cloud( # credits - [bevy](https://github.com/bevyengine/bevy) +- [bevy-hanabi](https://github.com/djeedai/bevy_hanabi) - [diff-gaussian-rasterization](https://github.com/graphdeco-inria/diff-gaussian-rasterization) +- [dreamgaussian](https://github.com/dreamgaussian/dreamgaussian) - [dynamic-3d-gaussians](https://github.com/JonathonLuiten/Dynamic3DGaussians) - [gaussian-splatting](https://github.com/graphdeco-inria/gaussian-splatting) - [gaussian-splatting-web](https://github.com/cvlab-epfl/gaussian-splatting-web) +- [making gaussian splats smaller](https://aras-p.info/blog/2023/09/13/Making-Gaussian-Splats-smaller/) +- [onesweep](https://arxiv.org/ftp/arxiv/papers/2206/2206.01784.pdf) - [point-visualizer](https://github.com/mosure/point-visualizer) - [rusty-automata](https://github.com/mosure/rusty-automata) +- [splat](https://github.com/antimatter15/splat) +- [splatter](https://github.com/Lichtso/splatter) - [sturdy-dollop](https://github.com/mosure/sturdy-dollop) +- [taichi_3d_gaussian_splatting](https://github.com/wanmeihuali/taichi_3d_gaussian_splatting) diff --git a/assets/scenes/icecream.ply b/assets/scenes/icecream.ply new file mode 100644 index 00000000..84b539ff Binary files /dev/null and b/assets/scenes/icecream.ply differ diff --git a/src/gaussian.rs b/src/gaussian.rs index aae397a4..d9742220 100644 --- a/src/gaussian.rs +++ b/src/gaussian.rs @@ -1,6 +1,9 @@ -use std::io::{ - BufReader, - Cursor, +use std::{ + io::{ + BufReader, + Cursor, + }, + marker::Copy, }; use bevy::{ @@ -10,19 +13,21 @@ use bevy::{ LoadContext, LoadedAsset, }, - reflect::TypeUuid, + reflect::{ + TypePath, + TypeUuid, + }, + render::render_resource::ShaderType, utils::BoxedFuture, }; +use bytemuck::{ + Pod, + Zeroable, +}; use crate::ply::parse_ply; -#[derive(Clone, Debug, Default, Reflect)] -pub struct AnisotropicCovariance { - pub mean: Vec3, - pub covariance: Mat3, -} - const fn num_sh_coefficients(degree: usize) -> usize { if degree == 0 { 1 @@ -32,30 +37,101 @@ const fn num_sh_coefficients(degree: usize) -> usize { } const SH_DEGREE: usize = 3; pub const MAX_SH_COEFF_COUNT: usize = num_sh_coefficients(SH_DEGREE) * 3; -#[derive(Clone, Debug, Reflect)] +#[derive(Clone, Copy, ShaderType, Pod, Zeroable)] +#[repr(C)] pub struct SphericalHarmonicCoefficients { - pub coefficients: [Vec3; MAX_SH_COEFF_COUNT], + pub coefficients: [f32; MAX_SH_COEFF_COUNT], } impl Default for SphericalHarmonicCoefficients { fn default() -> Self { Self { - coefficients: [Vec3::ZERO; MAX_SH_COEFF_COUNT], + coefficients: [0.0; MAX_SH_COEFF_COUNT], } } } -#[derive(Clone, Debug, Default, Reflect)] +#[derive(Clone, Default, Copy, ShaderType, Pod, Zeroable)] +#[repr(C)] pub struct Gaussian { - pub normal: Vec3, + //pub anisotropic_covariance: AnisotropicCovariance, + //pub normal: Vec3, + pub rotation: [f32; 4], + pub position: Vec3, + pub scale: Vec3, pub opacity: f32, - pub transform: Transform, - pub anisotropic_covariance: AnisotropicCovariance, pub spherical_harmonic: SphericalHarmonicCoefficients, + padding: f32, } -#[derive(Clone, Debug, Reflect, TypeUuid)] +#[derive(Clone, TypeUuid, TypePath)] #[uuid = "ac2f08eb-bc32-aabb-ff21-51571ea332d5"] -pub struct GaussianCloud(Vec); +pub struct GaussianCloud(pub Vec); + +impl GaussianCloud { + pub fn test_model() -> Self { + let origin = Gaussian { + rotation: [ + 1.0, + 0.0, + 0.0, + 0.0, + ], + position: Vec3::new(0.0, 0.0, 0.0), + scale: Vec3::new(0.5, 0.5, 0.5), + opacity: 0.8, + spherical_harmonic: SphericalHarmonicCoefficients{ + coefficients: [ + 1.0, 0.0, 1.0, + 0.0, 0.5, 0.0, + 0.3, 0.2, 0.0, + 0.4, 0.0, 0.2, + 0.1, 0.0, 0.0, + 0.0, 0.3, 0.3, + 0.0, 1.0, 1.0, + 0.3, 0.0, 0.0, + 0.0, 0.0, 0.0, + 0.0, 0.3, 1.0, + 0.5, 0.3, 0.0, + 0.2, 0.3, 0.1, + 0.6, 0.3, 0.1, + 0.0, 0.3, 0.2, + 0.0, 0.5, 0.3, + 0.6, 0.1, 0.2, + ], + }, + padding: 0.0, + }; + let mut cloud = GaussianCloud(Vec::new()); + + for &x in [-1.0, 1.0].iter() { + for &y in [-1.0, 1.0].iter() { + for &z in [-1.0, 1.0].iter() { + let mut g = origin.clone(); + g.position = Vec3::new(x, y, z); + cloud.0.push(g); + } + } + } + + cloud + } +} + + +#[derive(Component, Reflect, Clone)] +pub struct GaussianCloudSettings { + pub global_scale: f32, + pub global_transform: GlobalTransform, +} + +impl Default for GaussianCloudSettings { + fn default() -> Self { + Self { + global_scale: 1.0, + global_transform: Transform::IDENTITY.into(), + } + } +} #[derive(Default)] diff --git a/src/lib.rs b/src/lib.rs index d6c824df..e6a33107 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,11 @@ use bevy::prelude::*; -use gaussian::{ +pub use gaussian::{ + Gaussian, GaussianCloud, GaussianCloudLoader, + GaussianCloudSettings, + SphericalHarmonicCoefficients, }; use render::RenderPipelinePlugin; @@ -13,13 +16,17 @@ pub mod render; pub mod utils; -#[derive(Component, Default)] +#[derive(Bundle, Default, Reflect)] pub struct GaussianSplattingBundle { - pub transform: Transform, - pub verticies: Handle, + pub settings: GaussianCloudSettings, // TODO: implement global transform + pub cloud: Handle, } -// TODO: add render pipeline config + +#[derive(Component, Default)] +struct GaussianSplattingCamera; +// TODO: filter camera 3D entities + pub struct GaussianSplattingPlugin; impl Plugin for GaussianSplattingPlugin { @@ -27,8 +34,10 @@ impl Plugin for GaussianSplattingPlugin { app.add_asset::(); app.init_asset_loader::(); - app.add_plugins(RenderPipelinePlugin); + app.register_type::(); - // TODO: add GaussianSplattingBundle system + app.add_plugins(( + RenderPipelinePlugin, + )); } } diff --git a/src/main.rs b/src/main.rs index 789d7d2a..3f8cc1f0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,6 +12,7 @@ use bevy_panorbit_camera::{ }; use bevy_gaussian_splatting::{ + GaussianCloud, GaussianSplattingBundle, GaussianSplattingPlugin, utils::setup_hooks, @@ -41,10 +42,13 @@ impl Default for GaussianSplattingViewer { fn setup_gaussian_cloud( mut commands: Commands, - asset_server: Res, + _asset_server: Res, + mut gaussian_assets: ResMut>, ) { + let cloud = gaussian_assets.add(GaussianCloud::test_model()); commands.spawn(GaussianSplattingBundle { - verticies: asset_server.load("scenes/test.ply"), + cloud, + // cloud: _asset_server.load("scenes/icecream.ply"), ..Default::default() }); diff --git a/src/ply.rs b/src/ply.rs index 3e56c1ef..ab9fa032 100644 --- a/src/ply.rs +++ b/src/ply.rs @@ -9,10 +9,7 @@ use ply_rs::{ parser::Parser, }; -use crate::gaussian::{ - Gaussian, - MAX_SH_COEFF_COUNT, -}; +use crate::gaussian::Gaussian; impl PropertyAccess for Gaussian { @@ -22,40 +19,32 @@ impl PropertyAccess for Gaussian { fn set_property(&mut self, key: String, property: Property) { match (key.as_ref(), property) { - ("x", Property::Float(v)) => self.transform.translation.x = v, - ("y", Property::Float(v)) => self.transform.translation.y = v, - ("z", Property::Float(v)) => self.transform.translation.z = v, - ("nx", Property::Float(v)) => self.normal.x = v, - ("ny", Property::Float(v)) => self.normal.y = v, - ("nz", Property::Float(v)) => self.normal.z = v, - ("f_dc_0", Property::Float(v)) => self.spherical_harmonic.coefficients[0].x = v, - ("f_dc_1", Property::Float(v)) => self.spherical_harmonic.coefficients[0].y = v, - ("f_dc_2", Property::Float(v)) => self.spherical_harmonic.coefficients[0].z = v, - ("opacity", Property::Float(v)) => self.opacity = v, - ("scale_0", Property::Float(v)) => self.transform.scale.x = v, - ("scale_1", Property::Float(v)) => self.transform.scale.y = v, - ("scale_2", Property::Float(v)) => self.transform.scale.z = v, - ("rot_0", Property::Float(v)) => self.transform.rotation.x = v, - ("rot_1", Property::Float(v)) => self.transform.rotation.y = v, - ("rot_2", Property::Float(v)) => self.transform.rotation.z = v, - ("rot_3", Property::Float(v)) => self.transform.rotation.w = v, + ("x", Property::Float(v)) => self.position.x = v, + ("y", Property::Float(v)) => self.position.y = v, + ("z", Property::Float(v)) => self.position.z = v, + // ("nx", Property::Float(v)) => self.normal.x = v, + // ("ny", Property::Float(v)) => self.normal.y = v, + // ("nz", Property::Float(v)) => self.normal.z = v, + ("f_dc_0", Property::Float(v)) => self.spherical_harmonic.coefficients[0] = v, + ("f_dc_1", Property::Float(v)) => self.spherical_harmonic.coefficients[1] = v, + ("f_dc_2", Property::Float(v)) => self.spherical_harmonic.coefficients[2] = v, + ("opacity", Property::Float(v)) => self.opacity = 1.0 / (1.0 + (-v).exp()), + ("scale_0", Property::Float(v)) => self.scale.x = v.exp(), // TODO: variance cap: https://github.com/Lichtso/splatter/blob/c6b7a3894c25578cd29c9761619e4f194449e389/src/scene.rs#L235 + ("scale_1", Property::Float(v)) => self.scale.y = v.exp(), + ("scale_2", Property::Float(v)) => self.scale.z = v.exp(), + ("rot_0", Property::Float(v)) => self.rotation[0] = v, + ("rot_1", Property::Float(v)) => self.rotation[1] = v, + ("rot_2", Property::Float(v)) => self.rotation[2] = v, + ("rot_3", Property::Float(v)) => self.rotation[3] = v, (_, Property::Float(v)) if key.starts_with("f_rest_") => { let i = key[7..].parse::().unwrap(); - let sh_upper_bound = MAX_SH_COEFF_COUNT - 3; match i { - _ if i < sh_upper_bound => { - let i = i + 3; - let j = i / 3; - let k = i % 3; - - // TODO: verify this is the correct sh order - self.spherical_harmonic.coefficients[j][k] = v; - }, - _ => { - println!("unmapped property: {}", key); - println!("value: {}", v); + _ if i + 3 < self.spherical_harmonic.coefficients.len() => { + // TODO: verify this is the correct sh order (packed not planar) + self.spherical_harmonic.coefficients[i + 3] = v; }, + _ => { }, } } (_, _) => {}, diff --git a/src/render/gaussian.wgsl b/src/render/gaussian.wgsl index 4537977c..18637482 100644 --- a/src/render/gaussian.wgsl +++ b/src/render/gaussian.wgsl @@ -1,49 +1,46 @@ -#import bevy_gaussian_splatting::spherical_harmonics compute_color_from_sh_3_degree +#import bevy_render::globals Globals +#import bevy_render::view View + +#import bevy_gaussian_splatting::spherical_harmonics spherical_harmonics_lookup struct GaussianInput { - @location(0) position: vec3, - @location(1) log_scale: vec3, - @location(2) rot: vec4, - @location(3) opacity_logit: f32, - sh: array, n_sh_coeffs>, + @location(0) rotation: vec4, + @location(1) position: vec3, + @location(2) scale: vec3, + @location(3) opacity: f32, + sh: array, }; struct GaussianOutput { @builtin(position) position: vec4, - @location(0) color: vec3, - @location(1) uv: vec2, - @location(2) conic_and_opacity: vec4, + @location(0) @interpolate(flat) color: vec4, + @location(1) @interpolate(flat) conic: vec3, + @location(2) @interpolate(linear) uv: vec2, }; -struct SceneUniforms { - viewMatrix: mat4x4, - projMatrix: mat4x4, - camera_position: vec3, - tan_fovx: f32, - tan_fovy: f32, - focal_x: f32, - focal_y: f32, - scale_modifier: f32, +struct GaussianUniforms { + global_scale: f32, + transform: f32, }; -fn sigmoid(x: f32) -> f32 { - if (x >= 0.0) { - return 1.0 / (1.0 + exp(-x)); - } else { - let z = exp(x); - return z / (1.0 + z); - } -} +@group(0) @binding(0) var view: View; +@group(0) @binding(1) var globals: Globals; + +@group(1) @binding(0) var uniforms: GaussianUniforms; -// TODO: precompute cov3d -fn compute_cov3d(log_scale: vec3, rot: vec4) -> array { - let modifier = uniforms.scale_modifier; +@group(2) @binding(0) var points: array; + + +// https://github.com/cvlab-epfl/gaussian-splatting-web/blob/905b3c0fb8961e42c79ef97e64609e82383ca1c2/src/shaders.ts#L185 +// TODO: precompute +fn compute_cov3d(scale: vec3, rot: vec4) -> array { + let modifier = uniforms.global_scale; let S = mat3x3( - exp(log_scale.x) * modifier, 0.0, 0.0, - 0.0, exp(log_scale.y) * modifier, 0.0, - 0.0, 0.0, exp(log_scale.z) * modifier, + scale.x * modifier, 0.0, 0.0, + 0.0, scale.y * modifier, 0.0, + 0.0, 0.0, scale.z * modifier, ); let r = rot.x; @@ -70,13 +67,16 @@ fn compute_cov3d(log_scale: vec3, rot: vec4) -> array { ); } -fn compute_cov2d(position: vec3, log_scale: vec3, rot: vec4) -> vec3 { - let cov3d = compute_cov3d(log_scale, rot); +fn compute_cov2d(position: vec3, scale: vec3, rot: vec4) -> vec3 { + let cov3d = compute_cov3d(scale, rot); - var t = uniforms.viewMatrix * vec4(position, 1.0); + var t = view.inverse_view * vec4(position, 1.0); - let limx = 1.3 * uniforms.tan_fovx; - let limy = 1.3 * uniforms.tan_fovy; + let focal_x = 500.0; + let focal_y = 500.0; + + let limx = 1.3 * 0.5 * view.viewport.z / focal_x; + let limy = 1.3 * 0.5 * view.viewport.w / focal_y; let txtz = t.x / t.z; let tytz = t.y / t.z; @@ -84,13 +84,13 @@ fn compute_cov2d(position: vec3, log_scale: vec3, rot: vec4) -> v t.y = min(limy, max(-limy, tytz)) * t.z; let J = mat4x4( - uniforms.focal_x / t.z, 0.0, -(uniforms.focal_x * t.x) / (t.z * t.z), 0.0, - 0.0, uniforms.focal_y / t.z, -(uniforms.focal_y * t.y) / (t.z * t.z), 0.0, + focal_x / t.z, 0.0, -(focal_x * t.x) / (t.z * t.z), 0.0, + 0.0, focal_y / t.z, -(focal_y * t.y) / (t.z * t.z), 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0 ); - let W = transpose(uniforms.viewMatrix); + let W = transpose(view.inverse_view); let T = W * J; @@ -105,70 +105,278 @@ fn compute_cov2d(position: vec3, log_scale: vec3, rot: vec4) -> v // Apply low-pass filter: every Gaussian should be at least // one pixel wide/high. Discard 3rd row and column. - cov[0][0] += 0.3; - cov[1][1] += 0.3; + // cov[0][0] += 0.3; + // cov[1][1] += 0.3; return vec3(cov[0][0], cov[0][1], cov[1][1]); } -@binding(0) @group(0) var uniforms: SceneUniforms; -@binding(1) @group(1) var points: array; -const quadVertices = array, 6>( - vec2(-1.0, -1.0), - vec2(-1.0, 1.0), - vec2(1.0, -1.0), - vec2(1.0, 1.0), - vec2(-1.0, 1.0), - vec2(1.0, -1.0), -); +// https://github.com/Lichtso/splatter/blob/c6b7a3894c25578cd29c9761619e4f194449e389/src/shaders.wgsl#L125-L169 +fn quat_to_mat(p: vec4) -> mat3x3 { + var q = p * sqrt(2.0); + var yy = q.y * q.y; + var yz = q.y * q.z; + var yw = q.y * q.w; + var yx = q.y * q.x; + var zz = q.z * q.z; + var zw = q.z * q.w; + var zx = q.z * q.x; + var ww = q.w * q.w; + var wx = q.w * q.x; + return mat3x3( + 1.0 - zz - ww, yz + wx, yw - zx, + yz - wx, 1.0 - yy - ww, zw + yx, + yw + zx, zw - yx, 1.0 - yy - zz, + ); +} + +fn projected_covariance_of_ellipsoid(scale: vec3, rotation: vec4, translation: vec3) -> mat3x3 { + let camera_matrix = mat3x3( + view.view.x.xyz, + view.view.y.xyz, + view.view.z.xyz + ); + var transform = quat_to_mat(rotation); + transform.x *= scale.x; + transform.y *= scale.y; + transform.z *= scale.z; + + // 3D Covariance + var view_pos = view.view * vec4(translation, 1.0); + view_pos.x = clamp(view_pos.x / view_pos.z, -1.0, 1.0) * view_pos.z; + view_pos.y = clamp(view_pos.y / view_pos.z, -1.0, 1.0) * view_pos.z; + let T = transpose(transform) * camera_matrix * mat3x3( + 1.0 / view_pos.z, 0.0, -view_pos.x / (view_pos.z * view_pos.z), + 0.0, 1.0 / view_pos.z, -view_pos.y / (view_pos.z * view_pos.z), + 0.0, 0.0, 0.0, + ); + let covariance_matrix = transpose(T) * T; + + return covariance_matrix; +} + +fn projected_contour_of_ellipsoid(scale: vec3, rotation: vec4, translation: vec3) -> mat3x3 { + let camera_matrix = mat3x3( + view.inverse_view.x.xyz, + view.inverse_view.y.xyz, + view.inverse_view.z.xyz + ); + + var transform = quat_to_mat(rotation); + transform.x /= scale.x; + transform.y /= scale.y; + transform.z /= scale.z; + + let ray_origin = view.world_position - translation; + let local_ray_origin = ray_origin * transform; + let local_ray_origin_squared = local_ray_origin * local_ray_origin; + + let diagonal = 1.0 - local_ray_origin_squared.yxx - local_ray_origin_squared.zzy; + let triangle = local_ray_origin.yxx * local_ray_origin.zzy; + + let A = mat3x3( + diagonal.x, triangle.z, triangle.y, + triangle.z, diagonal.y, triangle.x, + triangle.y, triangle.x, diagonal.z, + ); + + transform = transpose(camera_matrix) * transform; + let M = transform * A * transpose(transform); + + return M; +} + +fn extract_translation_of_ellipse(M: mat3x3) -> vec2 { + let discriminant = M.x.x * M.y.y - M.x.y * M.x.y; + let inverse_discriminant = 1.0 / discriminant; + return vec2( + M.x.y * M.y.z - M.y.y * M.x.z, + M.x.y * M.x.z - M.x.x * M.y.z, + ) * inverse_discriminant; +} + +fn extract_rotation_of_ellipse(M: mat3x3) -> vec2 { + let a = (M.x.x - M.y.y) * (M.x.x - M.y.y); + let b = a + 4.0 * M.x.y * M.x.y; + let c = 0.5 * sqrt(a / b); + var j = sqrt(0.5 - c); + var k = -sqrt(0.5 + c) * sign(M.x.y) * sign(M.x.x - M.y.y); + if(M.x.y < 0.0 || M.x.x - M.y.y < 0.0) { + k = -k; + j = -j; + } + if(M.x.x - M.y.y < 0.0) { + let t = j; + j = -k; + k = t; + } + return vec2(j, k); +} + +fn extract_scale_of_ellipse(M: mat3x3, translation: vec2, rotation: vec2) -> vec2 { + let d = 2.0 * M.x.y * rotation.x * rotation.y; + let e = M.z.z - (M.x.x * translation.x * translation.x + M.y.y * translation.y * translation.y + 2.0 * M.x.y * translation.x * translation.y); + let semi_major_axis = sqrt(abs(e / (M.x.x * rotation.y * rotation.y + M.y.y * rotation.x * rotation.x - d))); + let semi_minor_axis = sqrt(abs(e / (M.x.x * rotation.x * rotation.x + M.y.y * rotation.y * rotation.y + d))); + + return vec2(semi_major_axis, semi_minor_axis); +} + +fn extract_scale_of_covariance(M: mat3x3) -> vec2 { + let a = (M.x.x - M.y.y) * (M.x.x - M.y.y); + let b = sqrt(a + 4.0 * M.x.y * M.x.y); + let semi_major_axis = sqrt((M.x.x + M.y.y + b) * 0.5); + let semi_minor_axis = sqrt((M.x.x + M.y.y - b) * 0.5); + return vec2(semi_major_axis, semi_minor_axis); +} + +fn world_to_clip(world_pos: vec3) -> vec4 { + let homogenous_pos = view.view_proj * vec4(world_pos, 1.0); + return vec4(homogenous_pos.xyz, 1.0) / (homogenous_pos.w + 0.0000001); +} + +fn in_frustum(clip_space_pos: vec3) -> bool { + return abs(clip_space_pos.x) < 1.1 + && abs(clip_space_pos.y) < 1.1 + && abs(clip_space_pos.z - 0.5) < 0.5; +} + +fn view_dimensions(projection: mat4x4) -> vec2 { + let near = projection[2][3] / (projection[2][2] + 1.0); + let right = near / projection[0][0]; + let top = near / projection[1][1]; + + return vec2(2.0 * right, 2.0 * top); +} + @vertex -fn vs_points(@builtin(vertex_index) vertex_index: u32) -> GaussianOutput { +fn vs_points( + @builtin(instance_index) instance_index: u32, + @builtin(vertex_index) vertex_index: u32, +) -> GaussianOutput { var output: GaussianOutput; - let pointIndex = vertex_index / 6u; - let quadIndex = vertex_index % 6u; - let quadOffset = quadVertices[quadIndex]; - let point = points[pointIndex]; + let point = points[instance_index]; + + if (!in_frustum(world_to_clip(point.position).xyz)) { + output.color = vec4(0.0, 0.0, 0.0, 0.0); + return output; + } + + var quad_vertices = array, 4>( + vec2(-1.0, -1.0), + vec2(-1.0, 1.0), + vec2( 1.0, -1.0), + vec2( 1.0, 1.0), + ); + + let quad_index = vertex_index % 4u; + let quad_offset = quad_vertices[quad_index]; + + let ray_direction = normalize(point.position - view.world_position); + output.color = vec4( + spherical_harmonics_lookup(ray_direction, point.sh), + point.opacity + ); + + let cov2d = compute_cov2d(point.position, point.scale, point.rotation); - let cov2d = compute_cov2d(point.position, point.log_scale, point.rot); let det = cov2d.x * cov2d.z - cov2d.y * cov2d.y; let det_inv = 1.0 / det; - let conic = vec3(cov2d.z * det_inv, -cov2d.y * det_inv, cov2d.x * det_inv); + + let conic = vec3( + cov2d.z * det_inv, + -cov2d.y * det_inv, + cov2d.x * det_inv + ); + output.conic = conic; + let mid = 0.5 * (cov2d.x + cov2d.z); let lambda_1 = mid + sqrt(max(0.1, mid * mid - det)); let lambda_2 = mid - sqrt(max(0.1, mid * mid - det)); let radius_px = ceil(3.0 * sqrt(max(lambda_1, lambda_2))); let radius_ndc = vec2( - radius_px / (canvas_height), - radius_px / (canvas_width), + radius_px / f32(view.viewport.z), + radius_px / f32(view.viewport.w), + ); + + output.uv = radius_px * quad_offset; + + var projected_position = view.view_proj * vec4(point.position, 1.0); + projected_position = projected_position / projected_position.w; + + output.position = vec4( + projected_position.xy + 2.0 * radius_ndc * quad_offset, + projected_position.zw, ); - output.conic_and_opacity = vec4(conic, sigmoid(point.opacity_logit)); - var projPosition = uniforms.projMatrix * vec4(point.position, 1.0); - projPosition = projPosition / projPosition.w; - output.position = vec4(projPosition.xy + 2 * radius_ndc * quadOffset, projPosition.zw); - output.color = compute_color_from_sh_3_degree(point.position, point.sh); - output.uv = radius_px * quadOffset; + // let M = projected_contour_of_ellipsoid( + // point.scale * uniforms.global_scale, + // point.rotation, + // point.position, + // ); + // let translation = extract_translation_of_ellipse(M); + // let rotation = extract_rotation_of_ellipse(M); + // //let semi_axes = extract_scale_of_ellipse(M, translation, rotation); + + // let covariance = projected_covariance_of_ellipsoid( + // point.scale * uniforms.global_scale, + // point.rotation, + // point.position + // ); + // let semi_axes = extract_scale_of_covariance(covariance); + + // let view_dimensions = view_dimensions(view.projection); + // let ellipse_size_bias = 0.2 * view_dimensions.x / f32(view.viewport.z); + + // let transformation = mat3x2( + // vec2(rotation.y, -rotation.x) * (ellipse_size_bias + semi_axes.x), + // vec2(rotation.x, rotation.y) * (ellipse_size_bias + semi_axes.y), + // translation, + // ); + + // let T = mat3x3( + // vec3(transformation.x, 0.0), + // vec3(transformation.y, 0.0), + // vec3(transformation.z, 1.0), + // ); + + // let ellipse_margin = 3.3; // should be 2.0 + // output.uv = quad_offset * ellipse_margin; + // output.position = vec4( + // (T * vec3(output.uv, 1.0)).xy / view_dimensions, + // 0.0, + // 1.0, + // ); return output; } @fragment -fn fs_main(input: PointOutput) -> @location(0) vec4 { - // we want the distance from the gaussian to the fragment while uv - // is the reverse +fn fs_main(input: GaussianOutput) -> @location(0) vec4 { + // let power = dot(input.uv, input.uv); + // let alpha = input.color.a * exp(-0.5 * power); + + // if (alpha < 1.0 / 255.0) { + // discard; + // } + + // return vec4(input.color.rgb * alpha, alpha); + + let d = -input.uv; - let conic = input.conic_and_opacity.xyz; + let conic = input.conic; let power = -0.5 * (conic.x * d.x * d.x + conic.z * d.y * d.y) + conic.y * d.x * d.y; - let opacity = input.conic_and_opacity.w; if (power > 0.0) { discard; } - let alpha = min(0.99, opacity * exp(power)); - - return vec4(input.color * alpha, alpha); + let alpha = min(0.99, input.color.a * exp(power)); + return vec4( + input.color.rgb * alpha, + alpha, + ); } diff --git a/src/render/mod.rs b/src/render/mod.rs index 26b53dcc..ac402b27 100644 --- a/src/render/mod.rs +++ b/src/render/mod.rs @@ -1,16 +1,75 @@ +use std::hash::Hash; + use bevy::{ + prelude::*, asset::{ load_internal_asset, HandleUntyped, + LoadState, + }, + core_pipeline::core_3d::Transparent3d, + ecs::{ + system::{ + lifetimeless::*, + SystemParamItem, + }, + query::ROQueryItem, }, - prelude::*, reflect::TypeUuid, + render::{ + Extract, + extract_component::{ + DynamicUniformIndex, + UniformComponentPlugin, + ComponentUniforms, + }, + globals::{ + GlobalsUniform, + GlobalsBuffer, + }, + mesh::GpuBufferInfo, + render_asset::{ + PrepareAssetError, + RenderAsset, + RenderAssets, + RenderAssetPlugin, + }, + render_phase::{ + AddRenderCommand, + DrawFunctions, + PhaseItem, + RenderCommand, + RenderCommandResult, + RenderPhase, + SetItemPipeline, + TrackedRenderPass, + }, + render_resource::*, + renderer::RenderDevice, + Render, + RenderApp, + RenderSet, + view::{ + ExtractedView, + ViewUniform, + ViewUniforms, + ViewUniformOffset, + }, + }, +}; + +use crate::gaussian::{ + Gaussian, + GaussianCloud, + GaussianCloudSettings, + MAX_SH_COEFF_COUNT, }; const GAUSSIAN_SHADER_HANDLE: HandleUntyped = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 68294581); const SPHERICAL_HARMONICS_SHADER_HANDLE: HandleUntyped = HandleUntyped::weak_from_u64(Shader::TYPE_UUID, 834667312); + #[derive(Default)] pub struct RenderPipelinePlugin; @@ -29,8 +88,545 @@ impl Plugin for RenderPipelinePlugin { "spherical_harmonics.wgsl", Shader::from_wgsl ); + + app.add_plugins(RenderAssetPlugin::::default()); + app.add_plugins(UniformComponentPlugin::::default()); + + if let Ok(render_app) = app.get_sub_app_mut(RenderApp) { + render_app + .add_render_command::() + .init_resource::() + .add_systems(ExtractSchedule, extract_gaussians) + .add_systems( + Render, + ( + queue_gaussian_bind_group.in_set(RenderSet::Queue), + queue_gaussian_view_bind_groups.in_set(RenderSet::Queue), + queue_gaussians.in_set(RenderSet::Queue), + ), + ); + } + } + + fn finish(&self, app: &mut App) { + if let Ok(render_app) = app.get_sub_app_mut(RenderApp) { + render_app + .init_resource::() + .init_resource::>(); + } + } +} + + +#[derive(Bundle)] +pub struct GpuGaussianSplattingBundle { + pub settings_uniform: GaussianCloudUniform, + pub verticies: Handle, +} + +#[derive(Debug, Clone)] +pub struct GpuGaussianCloud { + pub buffer: Buffer, + pub count: u32, + pub buffer_info: GpuBufferInfo, +} +impl RenderAsset for GaussianCloud { + type ExtractedAsset = GaussianCloud; + type PreparedAsset = GpuGaussianCloud; + type Param = SRes; + + fn extract_asset(&self) -> Self::ExtractedAsset { + self.clone() + } + + fn prepare_asset( + gaussian_cloud: Self::ExtractedAsset, + render_device: &mut SystemParamItem, + ) -> Result> { + let buffer = render_device.create_buffer_with_data(&BufferInitDescriptor { + label: Some("gaussian cloud buffer"), + contents: bytemuck::cast_slice(gaussian_cloud.0.as_slice()), + usage: BufferUsages::VERTEX | BufferUsages::COPY_DST | BufferUsages::STORAGE, + }); + + Ok(GpuGaussianCloud { + buffer, + count: gaussian_cloud.0.len() as u32, + buffer_info: GpuBufferInfo::NonIndexed, + }) + } +} + + +#[allow(clippy::too_many_arguments)] +fn queue_gaussians( + transparent_3d_draw_functions: Res>, + custom_pipeline: Res, + mut pipelines: ResMut>, + pipeline_cache: Res, + gaussian_clouds: Res>, + gaussian_splatting_bundles: Query<( + Entity, + &Handle, + )>, + mut views: Query<(&ExtractedView, &mut RenderPhase)>, +) { + let draw_custom = transparent_3d_draw_functions.read().id::(); + + for (_view, mut transparent_phase) in &mut views { + for (entity, verticies) in &gaussian_splatting_bundles { + if let Some(_cloud) = gaussian_clouds.get(verticies) { + let key = GaussianCloudPipelineKey { + + }; + + let pipeline = pipelines.specialize(&pipeline_cache, &custom_pipeline, key); + + transparent_phase.add(Transparent3d { + entity, + draw_function: draw_custom, + distance: 0.0, + pipeline, + }); + } + } + } +} + + +#[derive(Resource)] +pub struct GaussianCloudPipeline { + shader: Handle, + pub gaussian_cloud_layout: BindGroupLayout, + pub gaussian_uniform_layout: BindGroupLayout, + pub view_layout: BindGroupLayout, +} + +impl FromWorld for GaussianCloudPipeline { + fn from_world(render_world: &mut World) -> Self { + let render_device = render_world.resource::(); + + let view_layout_entries = vec![ + BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::VERTEX_FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Uniform, + has_dynamic_offset: true, + min_binding_size: Some(ViewUniform::min_size()), + }, + count: None, + }, + BindGroupLayoutEntry { + binding: 1, + visibility: ShaderStages::VERTEX_FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Uniform, + has_dynamic_offset: false, + min_binding_size: Some(GlobalsUniform::min_size()), + }, + count: None, + }, + ]; + + let view_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("gaussian_view_layout"), + entries: &view_layout_entries, + }); + + let gaussian_uniform_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("gaussian_uniform_layout"), + entries: &vec![ + BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::VERTEX_FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Uniform, + has_dynamic_offset: true, + min_binding_size: Some(GaussianCloudUniform::min_size()), + }, + count: None, + }, + ], + }); + + let gaussian_cloud_layout = render_device.create_bind_group_layout(&BindGroupLayoutDescriptor { + label: Some("gaussian_cloud_layout"), + entries: &vec![ + BindGroupLayoutEntry { + binding: 0, + visibility: ShaderStages::VERTEX_FRAGMENT, + ty: BindingType::Buffer { + ty: BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: BufferSize::new(std::mem::size_of::() as u64), + }, + count: None, + }, + ], + }); + + GaussianCloudPipeline { + gaussian_cloud_layout, + gaussian_uniform_layout, + view_layout, + shader: GAUSSIAN_SHADER_HANDLE.typed(), + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +pub struct GaussianCloudPipelineKey { + +} + +impl SpecializedRenderPipeline for GaussianCloudPipeline { + type Key = GaussianCloudPipelineKey; + + fn specialize(&self, _key: Self::Key) -> RenderPipelineDescriptor { + let shader_defs = vec![ + ShaderDefVal::UInt("MAX_SH_COEFF_COUNT".into(), MAX_SH_COEFF_COUNT as u32), + ]; + + RenderPipelineDescriptor { + label: Some("gaussian cloud pipeline".into()), + layout: vec![ + self.view_layout.clone(), + self.gaussian_uniform_layout.clone(), + self.gaussian_cloud_layout.clone(), + ], + vertex: VertexState { + shader: self.shader.clone(), + shader_defs: shader_defs.clone(), + entry_point: "vs_points".into(), + buffers: vec![], + }, + fragment: Some(FragmentState { + shader: self.shader.clone(), + shader_defs, + entry_point: "fs_main".into(), + targets: vec![Some(ColorTargetState { + format: TextureFormat::Rgba8UnormSrgb, + blend: Some(BlendState { + color: BlendComponent { + src_factor: BlendFactor::DstAlpha, + dst_factor: BlendFactor::One, + operation: BlendOperation::Add, + }, + alpha: BlendComponent { + src_factor: BlendFactor::Zero, + dst_factor: BlendFactor::OneMinusSrcAlpha, + operation: BlendOperation::Add, + }, + }), + write_mask: ColorWrites::ALL, + })], + }), + primitive: PrimitiveState { + topology: PrimitiveTopology::TriangleStrip, + strip_index_format: None, + front_face: FrontFace::Ccw, + unclipped_depth: false, + cull_mode: None, + conservative: false, + polygon_mode: PolygonMode::Fill, + }, + depth_stencil: Some(DepthStencilState { + format: TextureFormat::Depth32Float, + depth_write_enabled: false, + depth_compare: CompareFunction::GreaterEqual, + stencil: StencilState { + front: StencilFaceState::IGNORE, + back: StencilFaceState::IGNORE, + read_mask: 0, + write_mask: 0, + }, + bias: DepthBiasState { + constant: 0, + slope_scale: 0.0, + clamp: 0.0, + }, + }), + multisample: MultisampleState { + count: 4, + mask: !0, + alpha_to_coverage_enabled: false, + }, + push_constant_ranges: Vec::new(), + } } } +type DrawGaussians = ( + SetItemPipeline, + SetGaussianViewBindGroup<0>, + SetGaussianUniformBindGroup<1>, + DrawGaussianInstanced, +); -// TODO: implement gaussian cloud render pipeline + +#[derive(Component, ShaderType, Clone)] +pub struct GaussianCloudUniform { + pub global_scale: f32, + pub transform: Mat4, +} + +pub fn extract_gaussians( + mut commands: Commands, + mut prev_commands_len: Local, + gaussians_query: Extract< + Query<( + Entity, + // &ComputedVisibility, + &GaussianCloudSettings, + &Handle, + )>, + >, +) { + let mut commands_list = Vec::with_capacity(*prev_commands_len); + // let visible_gaussians = gaussians_query.iter().filter(|(_, vis, ..)| vis.is_visible()); + + for (entity, settings, verticies) in gaussians_query.iter() { + let settings_uniform = GaussianCloudUniform { + global_scale: settings.global_scale, + transform: settings.global_transform.compute_matrix(), + }; + commands_list.push(( + entity, + GpuGaussianSplattingBundle { + settings_uniform, + verticies: verticies.clone_weak(), + }, + )); + } + *prev_commands_len = commands_list.len(); + commands.insert_or_spawn_batch(commands_list); +} + + +#[derive(Resource, Default)] +pub struct GaussianUniformBindGroups { + base_bind_group: Option, +} + +#[derive(Component)] +pub struct GaussianCloudBindGroup { + pub bind_group: BindGroup, +} + +pub fn queue_gaussian_bind_group( + mut commands: Commands, + mut groups: ResMut, + gaussian_cloud_pipeline: Res, + render_device: Res, + gaussian_uniforms: Res>, + asset_server: Res, + gaussian_cloud_res: Res>, + gaussian_clouds: Query<( + Entity, + &Handle, + )>, +) { + let Some(model) = gaussian_uniforms.buffer() else { + return; + }; + + assert!(model.size() == std::mem::size_of::() as u64); + + for (entity, cloud_handle) in gaussian_clouds.iter() { + if asset_server.get_load_state(cloud_handle) == LoadState::Loading { + continue; + } + + let cloud = gaussian_cloud_res.get(cloud_handle).unwrap(); + + groups.base_bind_group = Some(render_device.create_bind_group(&BindGroupDescriptor { + entries: &[ + BindGroupEntry { + binding: 0, + resource: BindingResource::Buffer(BufferBinding { + buffer: model, + offset: 0, + size: BufferSize::new(model.size()), + }), + }, + ], + layout: &gaussian_cloud_pipeline.gaussian_uniform_layout, + label: Some("gaussian_uniform_bind_group"), + })); + + commands.entity(entity).insert(GaussianCloudBindGroup { + bind_group: render_device.create_bind_group(&BindGroupDescriptor { + entries: &[ + BindGroupEntry { + binding: 0, + resource: BindingResource::Buffer(BufferBinding { + buffer: &cloud.buffer, + offset: 0, + size: BufferSize::new(cloud.buffer.size()), + }), + }, + ], + layout: &gaussian_cloud_pipeline.gaussian_cloud_layout, + label: Some("gaussian_cloud_bind_group"), + }), + }); + } +} + + +#[derive(Component)] +pub struct GaussianViewBindGroup { + pub value: BindGroup, +} + +pub fn queue_gaussian_view_bind_groups( + mut commands: Commands, + render_device: Res, + gaussian_cloud_pipeline: Res, + view_uniforms: Res, + views: Query<( + Entity, + &ExtractedView, + &mut RenderPhase, + )>, + globals_buffer: Res, +) { + if let ( + Some(view_binding), + Some(globals), + ) = ( + view_uniforms.uniforms.binding(), + globals_buffer.buffer.binding(), + ) { + for ( + entity, + _extracted_view, + _render_phase, + ) in &views + { + let layout = &gaussian_cloud_pipeline.view_layout; + + let entries = vec![ + BindGroupEntry { + binding: 0, + resource: view_binding.clone(), + }, + BindGroupEntry { + binding: 1, + resource: globals.clone(), + }, + ]; + + let view_bind_group = render_device.create_bind_group(&BindGroupDescriptor { + entries: &entries, + label: Some("gaussian_view_bind_group"), + layout, + }); + + + commands.entity(entity).insert(GaussianViewBindGroup { + value: view_bind_group, + }); + } + } +} + +pub struct SetGaussianViewBindGroup; +impl RenderCommand

for SetGaussianViewBindGroup { + type Param = (); + type ViewWorldQuery = ( + Read, + Read, + ); + type ItemWorldQuery = (); + + #[inline] + fn render<'w>( + _item: &P, + (view_uniform, gaussian_view_bind_group): ROQueryItem< + 'w, + Self::ViewWorldQuery, + >, + _entity: (), + _: SystemParamItem<'w, '_, Self::Param>, + pass: &mut TrackedRenderPass<'w>, + ) -> RenderCommandResult { + pass.set_bind_group( + I, + &gaussian_view_bind_group.value, + &[view_uniform.offset], + ); + + RenderCommandResult::Success + } +} + + +pub struct SetGaussianUniformBindGroup; +impl RenderCommand

for SetGaussianUniformBindGroup { + type Param = SRes; + type ViewWorldQuery = (); + type ItemWorldQuery = Read>; + + #[inline] + fn render<'w>( + _item: &P, + _view: (), + gaussian_cloud_index: ROQueryItem, + bind_groups: SystemParamItem<'w, '_, Self::Param>, + pass: &mut TrackedRenderPass<'w>, + ) -> RenderCommandResult { + let bind_groups = bind_groups.into_inner(); + let bind_group = bind_groups.base_bind_group.as_ref().expect("bind group not initialized"); + + let mut set_bind_group = |indices: &[u32]| pass.set_bind_group(I, bind_group, indices); + let gaussian_cloud_index = gaussian_cloud_index.index(); + set_bind_group(&[gaussian_cloud_index]); + + RenderCommandResult::Success + } +} + +pub struct DrawGaussianInstanced; +impl RenderCommand

for DrawGaussianInstanced { + type Param = SRes>; + type ViewWorldQuery = (); + type ItemWorldQuery = ( + Read>, + Read, + ); + + #[inline] + fn render<'w>( + _item: &P, + _view: (), + (handle, bind_group): (&'w Handle, &'w GaussianCloudBindGroup), + gaussian_clouds: SystemParamItem<'w, '_, Self::Param>, + pass: &mut TrackedRenderPass<'w>, + ) -> RenderCommandResult { + let gpu_gaussian_cloud = match gaussian_clouds.into_inner().get(handle) { + Some(gpu_gaussian_cloud) => gpu_gaussian_cloud, + None => return RenderCommandResult::Failure, + }; + + pass.set_bind_group(2, &bind_group.bind_group, &[]); + + match &gpu_gaussian_cloud.buffer_info { + GpuBufferInfo::Indexed { + buffer, + index_format, + count, + } => { + pass.set_index_buffer(buffer.slice(..), 0, *index_format); + pass.draw_indexed(0..*count, 0, 0..gpu_gaussian_cloud.count as u32); + } + GpuBufferInfo::NonIndexed => { + pass.draw(0..4, 0..gpu_gaussian_cloud.count as u32); + } + + // TODO: add support for indirect draw and match over sort methods + } + RenderCommandResult::Success + } +} diff --git a/src/render/spherical_harmonics.wgsl b/src/render/spherical_harmonics.wgsl index 4c6393b4..3daafa37 100644 --- a/src/render/spherical_harmonics.wgsl +++ b/src/render/spherical_harmonics.wgsl @@ -1,63 +1,51 @@ #define_import_path bevy_gaussian_splatting::spherical_harmonics -const SH_C0 = 0.28209479177387814f; -const SH_C1 = 0.4886025119029199f; -const SH_C2 = array( - 1.0925484305920792f, - -1.0925484305920792f, - 0.31539156525252005f, - -1.0925484305920792f, - 0.5462742152960396f -); -const SH_C3 = array( - -0.5900435899266435f, - 2.890611442640554f, - -0.4570457994644658f, - 0.3731763325901154f, - -0.4570457994644658f, - 1.445305721320277f, - -0.5900435899266435f +const shc = array( + 0.28209479177387814, + -0.4886025119029199, + 0.4886025119029199, + -0.4886025119029199, + 1.0925484305920792, + -1.0925484305920792, + 0.31539156525252005, + -1.0925484305920792, + 0.5462742152960396, + -0.5900435899266435, + 2.890611442640554, + -0.4570457994644658, + 0.3731763325901154, + -0.4570457994644658, + 1.445305721320277, + -0.5900435899266435, ); -fn compute_color_from_sh_3_degree(position: vec3, sh: array, 16>) -> vec3 { - let dir = normalize(position - uniforms.camera_position); - var result = SH_C0 * sh[0]; - - // if deg > 0 - let x = dir.x; - let y = dir.y; - let z = dir.z; - - result = result + SH_C1 * (-y * sh[1] + z * sh[2] - x * sh[3]); - - let xx = x * x; - let yy = y * y; - let zz = z * z; - let xy = x * y; - let xz = x * z; - let yz = y * z; - - // if (sh_degree > 1) { - result = result + - SH_C2[0] * xy * sh[4] + - SH_C2[1] * yz * sh[5] + - SH_C2[2] * (2. * zz - xx - yy) * sh[6] + - SH_C2[3] * xz * sh[7] + - SH_C2[4] * (xx - yy) * sh[8]; - - // if (sh_degree > 2) { - result = result + - SH_C3[0] * y * (3. * xx - yy) * sh[9] + - SH_C3[1] * xy * z * sh[10] + - SH_C3[2] * y * (4. * zz - xx - yy) * sh[11] + - SH_C3[3] * z * (2. * zz - 3. * xx - 3. * yy) * sh[12] + - SH_C3[4] * x * (4. * zz - xx - yy) * sh[13] + - SH_C3[5] * z * (xx - yy) * sh[14] + - SH_C3[6] * x * (xx - 3. * yy) * sh[15]; - - // unconditional - result = result + 0.5; - - return max(result, vec3(0.)); +fn spherical_harmonics_lookup( + ray_direction: vec3, + sh: array, +) -> vec3 { + var rds = ray_direction * ray_direction; + var color = vec3(0.5); + + color += shc[ 0] * vec3(sh[0], sh[1], sh[2]); + + color += shc[ 1] * vec3(sh[3], sh[4], sh[5]) * ray_direction.y; + color += shc[ 2] * vec3(sh[6], sh[7], sh[8]) * ray_direction.z; + color += shc[ 3] * vec3(sh[9], sh[10], sh[11]) * ray_direction.x; + + color += shc[ 4] * vec3(sh[12], sh[13], sh[14]) * ray_direction.x * ray_direction.y; + color += shc[ 5] * vec3(sh[15], sh[16], sh[17]) * ray_direction.y * ray_direction.z; + color += shc[ 6] * vec3(sh[18], sh[19], sh[20]) * (2.0 * rds.z - rds.x - rds.y); + color += shc[ 7] * vec3(sh[21], sh[22], sh[23]) * ray_direction.x * ray_direction.z; + color += shc[ 8] * vec3(sh[24], sh[25], sh[26]) * (rds.x - rds.y); + + color += shc[ 9] * vec3(sh[27], sh[28], sh[29]) * ray_direction.y * (3.0 * rds.x - rds.y); + color += shc[10] * vec3(sh[30], sh[31], sh[32]) * ray_direction.x * ray_direction.y * ray_direction.z; + color += shc[11] * vec3(sh[33], sh[34], sh[35]) * ray_direction.y * (4.0 * rds.z - rds.x - rds.y); + color += shc[12] * vec3(sh[36], sh[37], sh[38]) * ray_direction.z * (2.0 * rds.z - 3.0 * rds.x - 3.0 * rds.y); + color += shc[13] * vec3(sh[39], sh[40], sh[41]) * ray_direction.x * (4.0 * rds.z - rds.x - rds.y); + color += shc[14] * vec3(sh[42], sh[43], sh[44]) * ray_direction.z * (rds.x - rds.y); + color += shc[15] * vec3(sh[45], sh[46], sh[47]) * ray_direction.x * (rds.x - 3.0 * rds.y); + + return color; }