diff --git a/CHANGELOG.md b/CHANGELOG.md index 3955b86b..5bdc5e17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- `AppReplicationExt::replicate_with` now accepts newly added `ReplicationFns`. +- `AppReplicationExt::replicate_with` now accepts newly added `ReplicationFns` and the function is now `unsafe` (it was never "safe" before, caller must ensure that used `C` can be passed to the serialization function). - Move `Replication` to `core` module. - Move all functions-related logic from `ReplicationRules` into a new `ReplicationFns`. - Rename `serialize_component` into `serialize` and move into `replication_fns` module. diff --git a/src/core/replication_fns.rs b/src/core/replication_fns.rs index 55e90c67..a62342a8 100644 --- a/src/core/replication_fns.rs +++ b/src/core/replication_fns.rs @@ -55,7 +55,7 @@ impl Default for ReplicationFns { } /// Signature of component serialization functions. -pub type SerializeFn = fn(Ptr, &mut Cursor>) -> bincode::Result<()>; +pub type SerializeFn = unsafe fn(Ptr, &mut Cursor>) -> bincode::Result<()>; /// Signature of component deserialization functions. pub type DeserializeFn = fn( @@ -121,12 +121,15 @@ impl ComponentFns { pub struct ComponentFnsId(usize); /// Default serialization function. -pub fn serialize( +/// +/// # Safety +/// +/// `C` must be the erased pointee type for this [`Ptr`]. +pub unsafe fn serialize( component: Ptr, cursor: &mut Cursor>, ) -> bincode::Result<()> { - // SAFETY: function called for registered `ComponentId`. - let component: &C = unsafe { component.deref() }; + let component: &C = component.deref(); DefaultOptions::new().serialize_into(cursor, component) } diff --git a/src/core/replication_rules.rs b/src/core/replication_rules.rs index 7e77c514..7b978504 100644 --- a/src/core/replication_rules.rs +++ b/src/core/replication_rules.rs @@ -7,7 +7,7 @@ use bevy::{ }; use serde::{de::DeserializeOwned, Serialize}; -use super::replication_fns::{self, ComponentFns, ComponentFnsId, ReplicationFns}; +use super::replication_fns::{ComponentFns, ComponentFnsId, ReplicationFns}; /// Replication functions for [`App`]. pub trait AppReplicationExt { @@ -26,7 +26,8 @@ pub trait AppReplicationExt { where C: Component + Serialize + DeserializeOwned, { - self.replicate_with::(ComponentFns::default_fns::()); + // SAFETY: Component is registered with the corresponding default serialization function. + unsafe { self.replicate_with::(ComponentFns::default_fns::()) }; self } @@ -39,7 +40,8 @@ pub trait AppReplicationExt { where C: Component + Serialize + DeserializeOwned + MapEntities, { - self.replicate_with::(ComponentFns::default_mapped_fns::()); + // SAFETY: Component is registered with the corresponding default serialization function. + unsafe { self.replicate_with::(ComponentFns::default_mapped_fns::()) }; self } @@ -49,6 +51,10 @@ pub trait AppReplicationExt { Can be used to customize how the component will be replicated or for components that don't implement [`Serialize`] or [`DeserializeOwned`]. + # Safety + + Caller must ensure that component `C` can be safely passed to [`ComponentFns::serialize`]. + # Examples ``` @@ -63,16 +69,22 @@ pub trait AppReplicationExt { # let mut app = App::new(); # app.add_plugins(RepliconPlugins); - app.replicate_with::(ComponentFns { - serialize: serialize_translation, - deserialize: deserialize_translation, - remove: replication_fns::remove::, - }); + // SAFETY: `serialize_translation` expects `Transform`. + unsafe { + app.replicate_with::(ComponentFns { + serialize: serialize_translation, + deserialize: deserialize_translation, + remove: replication_fns::remove::, + }); + } /// Serializes only `translation` from [`Transform`]. - fn serialize_translation(component: Ptr, cursor: &mut Cursor>) -> bincode::Result<()> { - // SAFETY: function called for registered `ComponentId`. - let transform: &Transform = unsafe { component.deref() }; + /// + /// # Safety + /// + /// [`Transform`] must be the erased pointee type for this [`Ptr`]. + unsafe fn serialize_translation(component: Ptr, cursor: &mut Cursor>) -> bincode::Result<()> { + let transform: &Transform = component.deref(); bincode::serialize_into(cursor, &transform.translation) } @@ -90,10 +102,10 @@ pub trait AppReplicationExt { } ``` - The [`remove`](replication_fns::remove) used in this example is the default component + The [`remove`](super::replication_fns::remove) used in this example is the default component removal function, but you can replace it with your own as well. */ - fn replicate_with(&mut self, component_fns: ComponentFns) -> &mut Self + unsafe fn replicate_with(&mut self, component_fns: ComponentFns) -> &mut Self where C: Component; @@ -140,7 +152,7 @@ pub trait AppReplicationExt { } impl AppReplicationExt for App { - fn replicate_with(&mut self, component_fns: ComponentFns) -> &mut Self + unsafe fn replicate_with(&mut self, component_fns: ComponentFns) -> &mut Self where C: Component, { @@ -148,8 +160,8 @@ impl AppReplicationExt for App { let mut replication_fns = self.world.resource_mut::(); let fns_id = replication_fns.register_component_fns(component_fns); - let mut rules = self.world.resource_mut::(); - rules.insert(ReplicationRule::new(vec![(component_id, fns_id)])); + let rule = ReplicationRule::new(vec![(component_id, fns_id)]); + self.world.resource_mut::().insert(rule); self } @@ -189,18 +201,28 @@ pub struct ReplicationRule { pub priority: usize, /// Rule components and their serialization/deserialization/removal functions. - pub components: Vec<(ComponentId, ComponentFnsId)>, + components: Vec<(ComponentId, ComponentFnsId)>, } impl ReplicationRule { /// Creates a new rule with priority equal to the number of serialized components. - pub fn new(components: Vec<(ComponentId, ComponentFnsId)>) -> Self { + /// + /// # Safety + /// + /// Caller must ensure that in each pair the associated component can be safely + /// passed to the associated [`ComponentFns::serialize`]. + pub unsafe fn new(components: Vec<(ComponentId, ComponentFnsId)>) -> Self { Self { priority: components.len(), components, } } + /// Returns associated components and functions IDs. + pub(crate) fn components(&self) -> &[(ComponentId, ComponentFnsId)] { + &self.components + } + /// Determines whether an archetype contains all components required by the rule. pub(crate) fn matches(&self, archetype: &Archetype) -> bool { self.components @@ -293,7 +315,8 @@ impl GroupReplication for PlayerBundle { (visibility_id, visibility_fns_id), ]; - ReplicationRule::new(components) + // SAFETY: all components can be safely passed to their serialization functions. + unsafe { ReplicationRule::new(components) } } } @@ -314,15 +337,12 @@ macro_rules! impl_registrations { let mut components = Vec::new(); $( let component_id = world.init_component::<$type>(); - let fns_id = replication_fns.register_component_fns(ComponentFns { - serialize: replication_fns::serialize::<$type>, - deserialize: replication_fns::deserialize::<$type>, - remove: replication_fns::remove::<$type>, - }); + let fns_id = replication_fns.register_component_fns(ComponentFns::default_fns::<$type>()); components.push((component_id, fns_id)); )* - ReplicationRule::new(components) + // SAFETY: Components are registered with the appropriate default serialization functions. + unsafe { ReplicationRule::new(components) } } } } diff --git a/src/scene.rs b/src/scene.rs index 543eeba3..8782b3b2 100644 --- a/src/scene.rs +++ b/src/scene.rs @@ -74,7 +74,7 @@ pub fn replicate_into(scene: &mut DynamicScene, world: &World) { } for rule in rules.iter().filter(|rule| rule.matches(archetype)) { - for &(component_id, _) in &rule.components { + for &(component_id, _) in rule.components() { // SAFETY: replication rules can be registered only with valid component IDs. let replicated_component = unsafe { world.components().get_info_unchecked(component_id) }; diff --git a/src/server/removal_buffer.rs b/src/server/removal_buffer.rs index 05fda02d..79db6dcc 100644 --- a/src/server/removal_buffer.rs +++ b/src/server/removal_buffer.rs @@ -130,7 +130,7 @@ impl FromWorld for ReplicatedComponents { let rules = world.resource::(); let component_ids = rules .iter() - .flat_map(|rule| &rule.components) + .flat_map(|rule| rule.components()) .map(|&(component_id, _)| component_id) .collect(); @@ -172,7 +172,7 @@ impl RemovalBuffer { .iter() .filter(|rule| rule.matches_removals(archetype, components)) { - for &(component_id, fns_id) in &rule.components { + for &(component_id, fns_id) in rule.components() { // Since rules are sorted by priority, // we are inserting only new components that aren't present. if removed_ids diff --git a/src/server/replicated_archetypes.rs b/src/server/replicated_archetypes.rs index b803cb92..cdfef7fb 100644 --- a/src/server/replicated_archetypes.rs +++ b/src/server/replicated_archetypes.rs @@ -47,7 +47,7 @@ impl ReplicatedArchetypes { { let mut replicated_archetype = ReplicatedArchetype::new(archetype.id()); for rule in rules.iter().filter(|rule| rule.matches(archetype)) { - for &(component_id, fns_id) in &rule.components { + for &(component_id, fns_id) in rule.components() { // Since rules are sorted by priority, // we are inserting only new components that aren't present. if replicated_archetype diff --git a/src/server/replication_messages.rs b/src/server/replication_messages.rs index 23e11d01..1ed4bf58 100644 --- a/src/server/replication_messages.rs +++ b/src/server/replication_messages.rs @@ -299,7 +299,9 @@ impl InitMessage { let size = write_with(shared_bytes, &mut self.cursor, |cursor| { DefaultOptions::new().serialize_into(&mut *cursor, &fns_id)?; - (fns.serialize)(ptr, cursor) + // SAFETY: User ensured that the registered component can be + // safely passed to its serialization function. + unsafe { (fns.serialize)(ptr, cursor) } })?; self.entity_data_size = self @@ -525,7 +527,9 @@ impl UpdateMessage { let size = write_with(shared_bytes, &mut self.cursor, |cursor| { DefaultOptions::new().serialize_into(&mut *cursor, &fns_id)?; - (fns.serialize)(ptr, cursor) + // SAFETY: User ensured that the registered component can be + // safely passed to its serialization function. + unsafe { (fns.serialize)(ptr, cursor) } })?; self.entity_data_size = self