Skip to content

Commit

Permalink
Remove task_pool parameter from par_for_each(_mut) (#4705)
Browse files Browse the repository at this point in the history
# Objective
Fixes #3183. Requiring a `&TaskPool` parameter is sort of meaningless if the only correct one is to use the one provided by `Res<ComputeTaskPool>` all the time.

## Solution
Have `QueryState` save a clone of the `ComputeTaskPool` which is used for all `par_for_each` functions.

~~Adds a small overhead of the internal `Arc` clone as a part of the startup, but the ergonomics win should be well worth this hardly-noticable overhead.~~

Updated the docs to note that it will panic the task pool is not present as a resource.

# Future Work
If bevyengine/rfcs#54 is approved, we can replace these resource lookups with a static function call instead to get the `ComputeTaskPool`.

---

## Changelog
Removed: The `task_pool` parameter of `Query(State)::par_for_each(_mut)`. These calls will use the `World`'s `ComputeTaskPool` resource instead.

## Migration Guide
The `task_pool` parameter for `Query(State)::par_for_each(_mut)` has been removed. Remove these parameters from all calls to these functions.

Before:
```rust
fn parallel_system(
   task_pool: Res<ComputeTaskPool>,
   query: Query<&MyComponent>,
) {
   query.par_for_each(&task_pool, 32, |comp| {
        ...
   });
}
```

After:

```rust
fn parallel_system(query: Query<&MyComponent>) {
   query.par_for_each(32, |comp| {
        ...
   });
}
```

If using `Query(State)` outside of a system run by the scheduler, you may need to manually configure and initialize a `ComputeTaskPool` as a resource in the `World`.
  • Loading branch information
james7132 committed May 30, 2022
1 parent f59ea7e commit c5e8989
Show file tree
Hide file tree
Showing 5 changed files with 149 additions and 125 deletions.
8 changes: 4 additions & 4 deletions benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use bevy_ecs::prelude::*;
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
use glam::*;

#[derive(Component, Copy, Clone)]
Expand Down Expand Up @@ -29,8 +29,8 @@ impl Benchmark {
)
}));

fn sys(task_pool: Res<TaskPool>, mut query: Query<(&mut Position, &mut Transform)>) {
query.par_for_each_mut(&task_pool, 128, |(mut pos, mut mat)| {
fn sys(mut query: Query<(&mut Position, &mut Transform)>) {
query.par_for_each_mut(128, |(mut pos, mut mat)| {
for _ in 0..100 {
mat.0 = mat.0.inverse();
}
Expand All @@ -39,7 +39,7 @@ impl Benchmark {
});
}

world.insert_resource(TaskPool::default());
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let mut system = IntoSystem::into_system(sys);
system.initialize(&mut world);
system.update_archetype_component_access(&world);
Expand Down
10 changes: 4 additions & 6 deletions crates/bevy_ecs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ mod tests {
query::{Added, ChangeTrackers, Changed, FilteredAccess, With, Without, WorldQuery},
world::{Mut, World},
};
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
use std::{
any::TypeId,
sync::{
Expand Down Expand Up @@ -376,7 +376,7 @@ mod tests {
#[test]
fn par_for_each_dense() {
let mut world = World::new();
let task_pool = TaskPool::default();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(A(1)).id();
let e2 = world.spawn().insert(A(2)).id();
let e3 = world.spawn().insert(A(3)).id();
Expand All @@ -385,7 +385,7 @@ mod tests {
let results = Arc::new(Mutex::new(Vec::new()));
world
.query::<(Entity, &A)>()
.par_for_each(&world, &task_pool, 2, |(e, &A(i))| {
.par_for_each(&world, 2, |(e, &A(i))| {
results.lock().unwrap().push((e, i));
});
results.lock().unwrap().sort();
Expand All @@ -398,8 +398,7 @@ mod tests {
#[test]
fn par_for_each_sparse() {
let mut world = World::new();

let task_pool = TaskPool::default();
world.insert_resource(ComputeTaskPool(TaskPool::default()));
let e1 = world.spawn().insert(SparseStored(1)).id();
let e2 = world.spawn().insert(SparseStored(2)).id();
let e3 = world.spawn().insert(SparseStored(3)).id();
Expand All @@ -408,7 +407,6 @@ mod tests {
let results = Arc::new(Mutex::new(Vec::new()));
world.query::<(Entity, &SparseStored)>().par_for_each(
&world,
&task_pool,
2,
|(e, &SparseStored(i))| results.lock().unwrap().push((e, i)),
);
Expand Down
216 changes: 120 additions & 96 deletions crates/bevy_ecs/src/query/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,18 @@ use crate::{
storage::TableId,
world::{World, WorldId},
};
use bevy_tasks::TaskPool;
use bevy_tasks::{ComputeTaskPool, TaskPool};
#[cfg(feature = "trace")]
use bevy_utils::tracing::Instrument;
use fixedbitset::FixedBitSet;
use std::fmt;
use std::{fmt, ops::Deref};

use super::{QueryFetch, QueryItem, ROQueryFetch, ROQueryItem};

/// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter.
pub struct QueryState<Q: WorldQuery, F: WorldQuery = ()> {
world_id: WorldId,
task_pool: Option<TaskPool>,
pub(crate) archetype_generation: ArchetypeGeneration,
pub(crate) matched_tables: FixedBitSet,
pub(crate) matched_archetypes: FixedBitSet,
Expand Down Expand Up @@ -61,6 +62,9 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {

let mut state = Self {
world_id: world.id(),
task_pool: world
.get_resource::<ComputeTaskPool>()
.map(|task_pool| task_pool.deref().clone()),
archetype_generation: ArchetypeGeneration::initial(),
matched_table_ids: Vec::new(),
matched_archetype_ids: Vec::new(),
Expand Down Expand Up @@ -689,15 +693,18 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
);
}

/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries, see [`Self::par_for_each_mut`] for
/// write-queries.
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
#[inline]
pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
Expand All @@ -706,7 +713,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<ROQueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
Expand All @@ -715,12 +721,15 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
}
}

/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// # Panics
/// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
#[inline]
pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w mut World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
Expand All @@ -729,7 +738,6 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
Expand All @@ -738,10 +746,14 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
}
}

/// Runs `func` on each query result in parallel using the given `task_pool`.
/// Runs `func` on each query result in parallel.
///
/// This can only be called for read-only queries.
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
Expand All @@ -750,14 +762,12 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
pub unsafe fn par_for_each_unchecked<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>(
&mut self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
) {
self.update_archetypes(world);
self.par_for_each_unchecked_manual::<QueryFetch<Q>, FN>(
world,
task_pool,
batch_size,
func,
world.last_change_tick(),
Expand Down Expand Up @@ -833,6 +843,10 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
/// the current change tick are given. This is faster than the equivalent
/// iter() method, but cannot be chained like a normal [`Iterator`].
///
/// # Panics
/// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query
/// that is being initialized and run from the ECS scheduler, this should never panic.
///
/// # Safety
///
/// This does not check for mutable query correctness. To be safe, make sure mutable queries
Expand All @@ -846,103 +860,113 @@ impl<Q: WorldQuery, F: WorldQuery> QueryState<Q, F> {
>(
&self,
world: &'w World,
task_pool: &TaskPool,
batch_size: usize,
func: FN,
last_change_tick: u32,
change_tick: u32,
) {
// NOTE: If you are changing query iteration code, remember to update the following places, where relevant:
// QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual
task_pool.scope(|scope| {
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
let tables = &world.storages().tables;
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
let len = batch_size.min(table.len() - offset);
let task = async move {
let mut fetch =
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
self.task_pool
.as_ref()
.expect("Cannot iterate query in parallel. No ComputeTaskPool initialized.")
.scope(|scope| {
if QF::IS_DENSE && <QueryFetch<'static, F>>::IS_DENSE {
let tables = &world.storages().tables;
for table_id in &self.matched_table_ids {
let table = &tables[*table_id];
let mut offset = 0;
while offset < table.len() {
let func = func.clone();
let len = batch_size.min(table.len() - offset);
let task = async move {
let mut fetch = QF::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let table = &tables[*table_id];
fetch.set_table(&self.fetch_state, table);
filter.set_table(&self.filter_state, table);
for table_index in offset..offset + len {
if !filter.table_filter_fetch(table_index) {
continue;
}
let item = fetch.table_fetch(table_index);
func(item);
}
let item = fetch.table_fetch(table_index);
func(item);
}
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let task = async move {
let mut fetch =
QF::init(world, &self.fetch_state, last_change_tick, change_tick);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
};
#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);

for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
#[cfg(feature = "trace")]
let task = task.instrument(span);
scope.spawn(task);
offset += batch_size;
}
}
} else {
let archetypes = &world.archetypes;
for archetype_id in &self.matched_archetype_ids {
let mut offset = 0;
let archetype = &archetypes[*archetype_id];
while offset < archetype.len() {
let func = func.clone();
let len = batch_size.min(archetype.len() - offset);
let task = async move {
let mut fetch = QF::init(
world,
&self.fetch_state,
last_change_tick,
change_tick,
);
let mut filter = <QueryFetch<F> as Fetch>::init(
world,
&self.filter_state,
last_change_tick,
change_tick,
);
let tables = &world.storages().tables;
let archetype = &world.archetypes[*archetype_id];
fetch.set_archetype(&self.fetch_state, archetype, tables);
filter.set_archetype(&self.filter_state, archetype, tables);

for archetype_index in offset..offset + len {
if !filter.archetype_filter_fetch(archetype_index) {
continue;
}
func(fetch.archetype_fetch(archetype_index));
}
func(fetch.archetype_fetch(archetype_index));
}
};

#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);

scope.spawn(task);
offset += batch_size;
};

#[cfg(feature = "trace")]
let span = bevy_utils::tracing::info_span!(
"par_for_each",
query = std::any::type_name::<Q>(),
filter = std::any::type_name::<F>(),
count = len,
);
#[cfg(feature = "trace")]
let task = task.instrument(span);

scope.spawn(task);
offset += batch_size;
}
}
}
}
});
});
}

/// Returns a single immutable query result when there is exactly one entity matching
Expand Down
Loading

0 comments on commit c5e8989

Please sign in to comment.