Skip to content

Commit

Permalink
feat: spherindrical constants
Browse files Browse the repository at this point in the history
  • Loading branch information
mosure committed Jan 26, 2024
1 parent 5dcac36 commit a23c334
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 3 deletions.
199 changes: 197 additions & 2 deletions src/material/spherindrical_harmonics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,78 @@ use serde::{
use half::f16;

use crate::material::spherical_harmonics::{
SH_CHANNELS,
SH_DEGREE,
};


const SH_4D_DEGREE_TIME: usize = 0;
const fn gcd(a: usize, b: usize) -> usize {
if b == 0 {
a
} else {
gcd(b, a % b)
}
}


pub const SH_4D_DEGREE_TIME: usize = 2;

pub const SH_4D_COEFF_COUNT_PER_CHANNEL: usize = (SH_DEGREE + 1).pow(2) * (SH_4D_DEGREE_TIME + 1);
pub const SH_4D_COEFF_COUNT: usize = (SH_4D_COEFF_COUNT_PER_CHANNEL * SH_CHANNELS + 3) & !3;

pub const HALF_SH_4D_COEFF_COUNT: usize = (SH_4D_COEFF_COUNT / 2 + 3) & !3;

// TODO: calculate POD_PLANE_COUNT for f16 and f32 based on a switch for HALF_SH_4D_COEFF_COUNT vs. SH_4D_COEFF_COUNT
pub const MAX_POD_U32_ARRAY_SIZE: usize = 32;
pub const POD_ARRAY_SIZE: usize = gcd(HALF_SH_4D_COEFF_COUNT, MAX_POD_U32_ARRAY_SIZE);
pub const POD_PLANE_COUNT: usize = HALF_SH_4D_COEFF_COUNT / POD_ARRAY_SIZE;

pub const WASTE: usize = POD_PLANE_COUNT * POD_ARRAY_SIZE - HALF_SH_4D_COEFF_COUNT;
static_assertions::const_assert_eq!(WASTE, 0);


#[cfg(feature = "f16")]
pub const SH_4D_VEC4_PLANES: usize = HALF_SH_4D_COEFF_COUNT / 4;
#[cfg(feature = "f32")]
pub const SH_4D_VEC4_PLANES: usize = SH_4D_COEFF_COUNT / 4;


const SPHERINDRICAL_HARMONICS_SHADER_HANDLE: Handle<Shader> = Handle::weak_from_u128(512346253);

pub struct SpherindricalHarmonicCoefficientsPlugin;
impl Plugin for SpherindricalHarmonicCoefficientsPlugin {
fn build(&self, app: &mut App) {
load_internal_asset!(
app,
SPHERINDRICAL_HARMONICS_SHADER_HANDLE,
"spherindrical_harmonics.wgsl",
Shader::from_wgsl
);
}
}


#[cfg(feature = "f16")]
#[derive(
Clone,
Copy,
Debug,
PartialEq,
Reflect,
ShaderType,
Pod,
Zeroable,
Serialize,
Deserialize,
)]
#[repr(C)]
pub struct SpherindricalHarmonicCoefficients {
#[reflect(ignore)]
#[serde(serialize_with = "coefficients_serializer", deserialize_with = "coefficients_deserializer")]
pub coefficients: [[u32; POD_ARRAY_SIZE]; POD_PLANE_COUNT],
}

#[cfg(feature = "f32")]
#[derive(
Clone,
Copy,
Expand All @@ -44,5 +108,136 @@ const SH_4D_DEGREE_TIME: usize = 0;
pub struct SpherindricalHarmonicCoefficients {
#[reflect(ignore)]
#[serde(serialize_with = "coefficients_serializer", deserialize_with = "coefficients_deserializer")]
pub coefficients: [u32; HALF_SH_COEFF_COUNT],
pub coefficients: [u32; SH_4D_COEFF_COUNT],
}


#[cfg(feature = "f16")]
impl Default for SpherindricalHarmonicCoefficients {
fn default() -> Self {
Self {
coefficients: [[0; POD_ARRAY_SIZE]; POD_PLANE_COUNT],
}
}
}

#[cfg(feature = "f32")]
impl Default for SpherindricalHarmonicCoefficients {
fn default() -> Self {
Self {
coefficients: [0.0; SH_4D_COEFF_COUNT],
}
}
}


impl SpherindricalHarmonicCoefficients {
#[cfg(feature = "f16")]
pub fn set(&mut self, index: usize, value: f32) {
let quantized = f16::from_f32(value).to_bits();
self.coefficients[index / 2] = match index % 2 {
0 => (self.coefficients[index / 2] & 0xffff0000) | (quantized as u32),
1 => (self.coefficients[index / 2] & 0x0000ffff) | ((quantized as u32) << 16),
_ => unreachable!(),
};
}

#[cfg(feature = "f32")]
pub fn set(&mut self, index: usize, value: f32) {
self.coefficients[index] = value;
}
}



#[cfg(feature = "f16")]
fn coefficients_serializer<S>(n: &[u32; HALF_SH_4D_COEFF_COUNT], s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut tup = s.serialize_tuple(SH_4D_COEFF_COUNT)?;
for &x in n.iter() {
tup.serialize_element(&x)?;
}

tup.end()
}

#[cfg(feature = "f16")]
fn coefficients_deserializer<'de, D>(d: D) -> Result<[u32; HALF_SH_4D_COEFF_COUNT], D::Error>
where
D: serde::Deserializer<'de>,
{
struct CoefficientsVisitor;

impl<'de> serde::de::Visitor<'de> for CoefficientsVisitor {
type Value = [u32; HALF_SH_4D_COEFF_COUNT];

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an array of floats")
}

fn visit_seq<A>(self, mut seq: A) -> Result<[u32; HALF_SH_4D_COEFF_COUNT], A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut coefficients = [0; HALF_SH_4D_COEFF_COUNT];

for (i, coefficient) in coefficients.iter_mut().enumerate().take(SH_4D_COEFF_COUNT) {
*coefficient = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(i, &self))?;
}
Ok(coefficients)
}
}

d.deserialize_tuple(SH_4D_COEFF_COUNT, CoefficientsVisitor)
}


#[cfg(feature = "f32")]
fn coefficients_serializer<S>(n: &[f32; SH_4D_COEFF_COUNT], s: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut tup = s.serialize_tuple(SH_4D_COEFF_COUNT)?;
for &x in n.iter() {
tup.serialize_element(&x)?;
}

tup.end()
}

#[cfg(feature = "f32")]
fn coefficients_deserializer<'de, D>(d: D) -> Result<[f32; SH_4D_COEFF_COUNT], D::Error>
where
D: serde::Deserializer<'de>,
{
struct CoefficientsVisitor;

impl<'de> serde::de::Visitor<'de> for CoefficientsVisitor {
type Value = [f32; SH_4D_COEFF_COUNT];

fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("an array of floats")
}

fn visit_seq<A>(self, mut seq: A) -> Result<[f32; SH_4D_COEFF_COUNT], A::Error>
where
A: serde::de::SeqAccess<'de>,
{
let mut coefficients = [0.0; SH_4D_COEFF_COUNT];

for (i, coefficient) in coefficients.iter_mut().enumerate().take(SH_4D_COEFF_COUNT) {
*coefficient = seq
.next_element()?
.ok_or_else(|| serde::de::Error::invalid_length(i, &self))?;
}
Ok(coefficients)
}
}

d.deserialize_tuple(SH_4D_COEFF_COUNT, CoefficientsVisitor)
}

2 changes: 1 addition & 1 deletion src/render/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ pub struct GpuGaussianSplattingBundle {
pub settings: GaussianCloudSettings,
pub settings_uniform: GaussianCloudUniform,
pub sorted_entries: Handle<SortedEntries>,
pub cloud_handle: Handle<GaussianCloud>,
pub cloud_handle: Handle<GaussianCloud>, // TODO: handle 4d gaussian cloud
}

#[derive(Debug, Clone)]
Expand Down

0 comments on commit a23c334

Please sign in to comment.