Skip to content

Commit

Permalink
Mark serializtion function as unsafe and related functions (#226)
Browse files Browse the repository at this point in the history
Caller must ensure that used `C` can be passed to the serialization
function.

Co-authored-by: UkoeHB <[email protected]>
  • Loading branch information
Shatur and UkoeHB authored Apr 3, 2024
1 parent b9596b7 commit 70a3bdc
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 36 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- `AppReplicationExt::replicate_with` now accepts newly added `ReplicationFns`.
- `AppReplicationExt::replicate_with` now accepts newly added `ReplicationFns` and the function is now `unsafe` (it was never "safe" before, caller must ensure that used `C` can be passed to the serialization function).
- Move `Replication` to `core` module.
- Move all functions-related logic from `ReplicationRules` into a new `ReplicationFns`.
- Rename `serialize_component` into `serialize` and move into `replication_fns` module.
Expand Down
11 changes: 7 additions & 4 deletions src/core/replication_fns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Default for ReplicationFns {
}

/// Signature of component serialization functions.
pub type SerializeFn = fn(Ptr, &mut Cursor<Vec<u8>>) -> bincode::Result<()>;
pub type SerializeFn = unsafe fn(Ptr, &mut Cursor<Vec<u8>>) -> bincode::Result<()>;

/// Signature of component deserialization functions.
pub type DeserializeFn = fn(
Expand Down Expand Up @@ -121,12 +121,15 @@ impl ComponentFns {
pub struct ComponentFnsId(usize);

/// Default serialization function.
pub fn serialize<C: Component + Serialize>(
///
/// # Safety
///
/// `C` must be the erased pointee type for this [`Ptr`].
pub unsafe fn serialize<C: Component + Serialize>(
component: Ptr,
cursor: &mut Cursor<Vec<u8>>,
) -> bincode::Result<()> {
// SAFETY: function called for registered `ComponentId`.
let component: &C = unsafe { component.deref() };
let component: &C = component.deref();
DefaultOptions::new().serialize_into(cursor, component)
}

Expand Down
70 changes: 45 additions & 25 deletions src/core/replication_rules.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use bevy::{
};
use serde::{de::DeserializeOwned, Serialize};

use super::replication_fns::{self, ComponentFns, ComponentFnsId, ReplicationFns};
use super::replication_fns::{ComponentFns, ComponentFnsId, ReplicationFns};

/// Replication functions for [`App`].
pub trait AppReplicationExt {
Expand All @@ -26,7 +26,8 @@ pub trait AppReplicationExt {
where
C: Component + Serialize + DeserializeOwned,
{
self.replicate_with::<C>(ComponentFns::default_fns::<C>());
// SAFETY: Component is registered with the corresponding default serialization function.
unsafe { self.replicate_with::<C>(ComponentFns::default_fns::<C>()) };
self
}

Expand All @@ -39,7 +40,8 @@ pub trait AppReplicationExt {
where
C: Component + Serialize + DeserializeOwned + MapEntities,
{
self.replicate_with::<C>(ComponentFns::default_mapped_fns::<C>());
// SAFETY: Component is registered with the corresponding default serialization function.
unsafe { self.replicate_with::<C>(ComponentFns::default_mapped_fns::<C>()) };
self
}

Expand All @@ -49,6 +51,10 @@ pub trait AppReplicationExt {
Can be used to customize how the component will be replicated or
for components that don't implement [`Serialize`] or [`DeserializeOwned`].
# Safety
Caller must ensure that component `C` can be safely passed to [`ComponentFns::serialize`].
# Examples
```
Expand All @@ -63,16 +69,22 @@ pub trait AppReplicationExt {
# let mut app = App::new();
# app.add_plugins(RepliconPlugins);
app.replicate_with::<Transform>(ComponentFns {
serialize: serialize_translation,
deserialize: deserialize_translation,
remove: replication_fns::remove::<Transform>,
});
// SAFETY: `serialize_translation` expects `Transform`.
unsafe {
app.replicate_with::<Transform>(ComponentFns {
serialize: serialize_translation,
deserialize: deserialize_translation,
remove: replication_fns::remove::<Transform>,
});
}
/// Serializes only `translation` from [`Transform`].
fn serialize_translation(component: Ptr, cursor: &mut Cursor<Vec<u8>>) -> bincode::Result<()> {
// SAFETY: function called for registered `ComponentId`.
let transform: &Transform = unsafe { component.deref() };
///
/// # Safety
///
/// [`Transform`] must be the erased pointee type for this [`Ptr`].
unsafe fn serialize_translation(component: Ptr, cursor: &mut Cursor<Vec<u8>>) -> bincode::Result<()> {
let transform: &Transform = component.deref();
bincode::serialize_into(cursor, &transform.translation)
}
Expand All @@ -90,10 +102,10 @@ pub trait AppReplicationExt {
}
```
The [`remove`](replication_fns::remove) used in this example is the default component
The [`remove`](super::replication_fns::remove) used in this example is the default component
removal function, but you can replace it with your own as well.
*/
fn replicate_with<C>(&mut self, component_fns: ComponentFns) -> &mut Self
unsafe fn replicate_with<C>(&mut self, component_fns: ComponentFns) -> &mut Self
where
C: Component;

Expand Down Expand Up @@ -140,16 +152,16 @@ pub trait AppReplicationExt {
}

impl AppReplicationExt for App {
fn replicate_with<C>(&mut self, component_fns: ComponentFns) -> &mut Self
unsafe fn replicate_with<C>(&mut self, component_fns: ComponentFns) -> &mut Self
where
C: Component,
{
let component_id = self.world.init_component::<C>();
let mut replication_fns = self.world.resource_mut::<ReplicationFns>();
let fns_id = replication_fns.register_component_fns(component_fns);

let mut rules = self.world.resource_mut::<ReplicationRules>();
rules.insert(ReplicationRule::new(vec![(component_id, fns_id)]));
let rule = ReplicationRule::new(vec![(component_id, fns_id)]);
self.world.resource_mut::<ReplicationRules>().insert(rule);

self
}
Expand Down Expand Up @@ -189,18 +201,28 @@ pub struct ReplicationRule {
pub priority: usize,

/// Rule components and their serialization/deserialization/removal functions.
pub components: Vec<(ComponentId, ComponentFnsId)>,
components: Vec<(ComponentId, ComponentFnsId)>,
}

impl ReplicationRule {
/// Creates a new rule with priority equal to the number of serialized components.
pub fn new(components: Vec<(ComponentId, ComponentFnsId)>) -> Self {
///
/// # Safety
///
/// Caller must ensure that in each pair the associated component can be safely
/// passed to the associated [`ComponentFns::serialize`].
pub unsafe fn new(components: Vec<(ComponentId, ComponentFnsId)>) -> Self {
Self {
priority: components.len(),
components,
}
}

/// Returns associated components and functions IDs.
pub(crate) fn components(&self) -> &[(ComponentId, ComponentFnsId)] {
&self.components
}

/// Determines whether an archetype contains all components required by the rule.
pub(crate) fn matches(&self, archetype: &Archetype) -> bool {
self.components
Expand Down Expand Up @@ -293,7 +315,8 @@ impl GroupReplication for PlayerBundle {
(visibility_id, visibility_fns_id),
];
ReplicationRule::new(components)
// SAFETY: all components can be safely passed to their serialization functions.
unsafe { ReplicationRule::new(components) }
}
}
Expand All @@ -314,15 +337,12 @@ macro_rules! impl_registrations {
let mut components = Vec::new();
$(
let component_id = world.init_component::<$type>();
let fns_id = replication_fns.register_component_fns(ComponentFns {
serialize: replication_fns::serialize::<$type>,
deserialize: replication_fns::deserialize::<$type>,
remove: replication_fns::remove::<$type>,
});
let fns_id = replication_fns.register_component_fns(ComponentFns::default_fns::<$type>());
components.push((component_id, fns_id));
)*

ReplicationRule::new(components)
// SAFETY: Components are registered with the appropriate default serialization functions.
unsafe { ReplicationRule::new(components) }
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/scene.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn replicate_into(scene: &mut DynamicScene, world: &World) {
}

for rule in rules.iter().filter(|rule| rule.matches(archetype)) {
for &(component_id, _) in &rule.components {
for &(component_id, _) in rule.components() {
// SAFETY: replication rules can be registered only with valid component IDs.
let replicated_component =
unsafe { world.components().get_info_unchecked(component_id) };
Expand Down
4 changes: 2 additions & 2 deletions src/server/removal_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ impl FromWorld for ReplicatedComponents {
let rules = world.resource::<ReplicationRules>();
let component_ids = rules
.iter()
.flat_map(|rule| &rule.components)
.flat_map(|rule| rule.components())
.map(|&(component_id, _)| component_id)
.collect();

Expand Down Expand Up @@ -172,7 +172,7 @@ impl RemovalBuffer {
.iter()
.filter(|rule| rule.matches_removals(archetype, components))
{
for &(component_id, fns_id) in &rule.components {
for &(component_id, fns_id) in rule.components() {
// Since rules are sorted by priority,
// we are inserting only new components that aren't present.
if removed_ids
Expand Down
2 changes: 1 addition & 1 deletion src/server/replicated_archetypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl ReplicatedArchetypes {
{
let mut replicated_archetype = ReplicatedArchetype::new(archetype.id());
for rule in rules.iter().filter(|rule| rule.matches(archetype)) {
for &(component_id, fns_id) in &rule.components {
for &(component_id, fns_id) in rule.components() {
// Since rules are sorted by priority,
// we are inserting only new components that aren't present.
if replicated_archetype
Expand Down
8 changes: 6 additions & 2 deletions src/server/replication_messages.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,9 @@ impl InitMessage {

let size = write_with(shared_bytes, &mut self.cursor, |cursor| {
DefaultOptions::new().serialize_into(&mut *cursor, &fns_id)?;
(fns.serialize)(ptr, cursor)
// SAFETY: User ensured that the registered component can be
// safely passed to its serialization function.
unsafe { (fns.serialize)(ptr, cursor) }
})?;

self.entity_data_size = self
Expand Down Expand Up @@ -525,7 +527,9 @@ impl UpdateMessage {

let size = write_with(shared_bytes, &mut self.cursor, |cursor| {
DefaultOptions::new().serialize_into(&mut *cursor, &fns_id)?;
(fns.serialize)(ptr, cursor)
// SAFETY: User ensured that the registered component can be
// safely passed to its serialization function.
unsafe { (fns.serialize)(ptr, cursor) }
})?;

self.entity_data_size = self
Expand Down

0 comments on commit 70a3bdc

Please sign in to comment.