Skip to content

Commit

Permalink
Live reloading of shaders (#937)
Browse files Browse the repository at this point in the history
* Add ShaderLoader, rebuild pipelines for modified shader assets
* New example
* Add shader_update_system, ShaderError, remove specialization assets
* Don't panic on shader compilation failure
  • Loading branch information
yrns authored Dec 7, 2020
1 parent a3bca7e commit 2c9b795
Show file tree
Hide file tree
Showing 11 changed files with 329 additions and 52 deletions.
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ path = "examples/reflection/trait_reflection.rs"
name = "scene"
path = "examples/scene/scene.rs"

[[example]]
name = "hot_shader_reloading"
path = "examples/shader/hot_shader_reloading.rs"

[[example]]
name = "mesh_custom_attribute"
path = "examples/shader/mesh_custom_attribute.rs"
Expand Down
11 changes: 11 additions & 0 deletions assets/shaders/hot.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#version 450

layout(location = 0) out vec4 o_Target;

layout(set = 2, binding = 0) uniform MyMaterial_color {
vec4 color;
};

void main() {
o_Target = color * 0.5;
}
15 changes: 15 additions & 0 deletions assets/shaders/hot.vert
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#version 450

layout(location = 0) in vec3 Vertex_Position;

layout(set = 0, binding = 0) uniform Camera {
mat4 ViewProj;
};

layout(set = 1, binding = 0) uniform Transform {
mat4 Model;
};

void main() {
gl_Position = ViewProj * Model * vec4(Vertex_Position, 1.0);
}
4 changes: 4 additions & 0 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use render_graph::{
RenderGraph,
};
use renderer::{AssetRenderResourceBindings, RenderResourceBindings};
use shader::ShaderLoader;
#[cfg(feature = "hdr")]
use texture::HdrTextureLoader;
#[cfg(feature = "png")]
Expand Down Expand Up @@ -87,6 +88,8 @@ impl Plugin for RenderPlugin {
app.init_asset_loader::<HdrTextureLoader>();
}

app.init_asset_loader::<ShaderLoader>();

if app.resources().get::<ClearColor>().is_none() {
app.resources_mut().insert(ClearColor::default());
}
Expand Down Expand Up @@ -134,6 +137,7 @@ impl Plugin for RenderPlugin {
camera::visible_entities_system,
)
// TODO: turn these "resource systems" into graph nodes and remove the RENDER_RESOURCE stage
.add_system_to_stage(stage::RENDER_RESOURCE, shader::shader_update_system)
.add_system_to_stage(stage::RENDER_RESOURCE, mesh::mesh_resource_provider_system)
.add_system_to_stage(stage::RENDER_RESOURCE, Texture::texture_resource_system)
.add_system_to_stage(
Expand Down
109 changes: 91 additions & 18 deletions crates/bevy_render/src/pipeline/pipeline_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto
use crate::{
pipeline::{BindType, InputStepMode, VertexBufferDescriptor},
renderer::RenderResourceContext,
shader::{Shader, ShaderSource},
shader::{Shader, ShaderError, ShaderSource},
};
use bevy_asset::{Assets, Handle};
use bevy_reflect::Reflect;
Expand Down Expand Up @@ -60,6 +60,7 @@ struct SpecializedPipeline {
#[derive(Debug, Default)]
pub struct PipelineCompiler {
specialized_shaders: HashMap<Handle<Shader>, Vec<SpecializedShader>>,
specialized_shader_pipelines: HashMap<Handle<Shader>, Vec<Handle<PipelineDescriptor>>>,
specialized_pipelines: HashMap<Handle<PipelineDescriptor>, Vec<SpecializedPipeline>>,
}

Expand All @@ -70,7 +71,7 @@ impl PipelineCompiler {
shaders: &mut Assets<Shader>,
shader_handle: &Handle<Shader>,
shader_specialization: &ShaderSpecialization,
) -> Handle<Shader> {
) -> Result<Handle<Shader>, ShaderError> {
let specialized_shaders = self
.specialized_shaders
.entry(shader_handle.clone_weak())
Expand All @@ -80,7 +81,7 @@ impl PipelineCompiler {

// don't produce new shader if the input source is already spirv
if let ShaderSource::Spirv(_) = shader.source {
return shader_handle.clone_weak();
return Ok(shader_handle.clone_weak());
}

if let Some(specialized_shader) =
Expand All @@ -91,7 +92,7 @@ impl PipelineCompiler {
})
{
// if shader has already been compiled with current configuration, use existing shader
specialized_shader.shader.clone_weak()
Ok(specialized_shader.shader.clone_weak())
} else {
// if no shader exists with the current configuration, create new shader and compile
let shader_def_vec = shader_specialization
Expand All @@ -100,14 +101,14 @@ impl PipelineCompiler {
.cloned()
.collect::<Vec<String>>();
let compiled_shader =
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec));
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec))?;
let specialized_handle = shaders.add(compiled_shader);
let weak_specialized_handle = specialized_handle.clone_weak();
specialized_shaders.push(SpecializedShader {
shader: specialized_handle,
specialization: shader_specialization.clone(),
});
weak_specialized_handle
Ok(weak_specialized_handle)
}
}

Expand Down Expand Up @@ -138,23 +139,31 @@ impl PipelineCompiler {
) -> Handle<PipelineDescriptor> {
let source_descriptor = pipelines.get(source_pipeline).unwrap();
let mut specialized_descriptor = source_descriptor.clone();
specialized_descriptor.shader_stages.vertex = self.compile_shader(
render_resource_context,
shaders,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
);
let specialized_vertex_shader = self
.compile_shader(
render_resource_context,
shaders,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
)
.unwrap();
specialized_descriptor.shader_stages.vertex = specialized_vertex_shader.clone_weak();
let mut specialized_fragment_shader = None;
specialized_descriptor.shader_stages.fragment = specialized_descriptor
.shader_stages
.fragment
.as_ref()
.map(|fragment| {
self.compile_shader(
render_resource_context,
shaders,
fragment,
&pipeline_specialization.shader_specialization,
)
let shader = self
.compile_shader(
render_resource_context,
shaders,
fragment,
&pipeline_specialization.shader_specialization,
)
.unwrap();
specialized_fragment_shader = Some(shader.clone_weak());
shader
});

let mut layout = render_resource_context.reflect_pipeline_layout(
Expand Down Expand Up @@ -244,6 +253,18 @@ impl PipelineCompiler {
&shaders,
);

// track specialized shader pipelines
self.specialized_shader_pipelines
.entry(specialized_vertex_shader)
.or_insert_with(Default::default)
.push(source_pipeline.clone_weak());
if let Some(specialized_fragment_shader) = specialized_fragment_shader {
self.specialized_shader_pipelines
.entry(specialized_fragment_shader)
.or_insert_with(Default::default)
.push(source_pipeline.clone_weak());
}

let specialized_pipelines = self
.specialized_pipelines
.entry(source_pipeline.clone_weak())
Expand Down Expand Up @@ -282,4 +303,56 @@ impl PipelineCompiler {
})
.flatten()
}

/// Update specialized shaders and remove any related specialized
/// pipelines and assets.
pub fn update_shader(
&mut self,
shader: &Handle<Shader>,
pipelines: &mut Assets<PipelineDescriptor>,
shaders: &mut Assets<Shader>,
render_resource_context: &dyn RenderResourceContext,
) -> Result<(), ShaderError> {
if let Some(specialized_shaders) = self.specialized_shaders.get_mut(shader) {
for specialized_shader in specialized_shaders {
// Recompile specialized shader. If it fails, we bail immediately.
let shader_def_vec = specialized_shader
.specialization
.shader_defs
.iter()
.cloned()
.collect::<Vec<String>>();
let new_handle =
shaders.add(render_resource_context.get_specialized_shader(
shaders.get(shader).unwrap(),
Some(&shader_def_vec),
)?);

// Replace handle and remove old from assets.
let old_handle = std::mem::replace(&mut specialized_shader.shader, new_handle);
shaders.remove(&old_handle);

// Find source pipelines that use the old specialized
// shader, and remove from tracking.
if let Some(source_pipelines) =
self.specialized_shader_pipelines.remove(&old_handle)
{
// Remove all specialized pipelines from tracking
// and asset storage. They will be rebuilt on next
// draw.
for source_pipeline in source_pipelines {
if let Some(specialized_pipelines) =
self.specialized_pipelines.remove(&source_pipeline)
{
for p in specialized_pipelines {
pipelines.remove(p.pipeline);
}
}
}
}
}
}

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::RenderResourceContext;
use crate::{
pipeline::{BindGroupDescriptorId, PipelineDescriptor},
renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId},
shader::Shader,
shader::{Shader, ShaderError},
texture::{SamplerDescriptor, TextureDescriptor},
};
use bevy_asset::{Assets, Handle, HandleUntyped};
Expand Down Expand Up @@ -149,8 +149,12 @@ impl RenderResourceContext for HeadlessRenderResourceContext {
size
}

fn get_specialized_shader(&self, shader: &Shader, _macros: Option<&[String]>) -> Shader {
shader.clone()
fn get_specialized_shader(
&self,
shader: &Shader,
_macros: Option<&[String]>,
) -> Result<Shader, ShaderError> {
Ok(shader.clone())
}

fn remove_stale_bind_groups(&self) {}
Expand Down
8 changes: 6 additions & 2 deletions crates/bevy_render/src/renderer/render_resource_context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
pipeline::{BindGroupDescriptorId, PipelineDescriptor, PipelineLayout},
renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId},
shader::{Shader, ShaderLayout, ShaderStages},
shader::{Shader, ShaderError, ShaderLayout, ShaderStages},
texture::{SamplerDescriptor, TextureDescriptor},
};
use bevy_asset::{Asset, Assets, Handle, HandleUntyped};
Expand Down Expand Up @@ -29,7 +29,11 @@ pub trait RenderResourceContext: Downcast + Send + Sync + 'static {
fn create_buffer_with_data(&self, buffer_info: BufferInfo, data: &[u8]) -> BufferId;
fn create_shader_module(&self, shader_handle: &Handle<Shader>, shaders: &Assets<Shader>);
fn create_shader_module_from_source(&self, shader_handle: &Handle<Shader>, shader: &Shader);
fn get_specialized_shader(&self, shader: &Shader, macros: Option<&[String]>) -> Shader;
fn get_specialized_shader(
&self,
shader: &Shader,
macros: Option<&[String]>,
) -> Result<Shader, ShaderError>;
fn remove_buffer(&self, buffer: BufferId);
fn remove_texture(&self, texture: TextureId);
fn remove_sampler(&self, sampler: SamplerId);
Expand Down
Loading

0 comments on commit 2c9b795

Please sign in to comment.