From b0bb3b123a0f65380ff0c313f85e27a2abe15ee8 Mon Sep 17 00:00:00 2001 From: khyperia <953151+khyperia@users.noreply.github.com> Date: Mon, 23 Aug 2021 13:32:42 +0200 Subject: [PATCH 1/3] Implement ByteAddressableBuffer prototype --- crates/rustc_codegen_spirv/src/attr.rs | 22 ++ .../src/builder/builder_methods.rs | 22 +- .../src/builder/byte_addressable_buffer.rs | 347 ++++++++++++++++++ crates/rustc_codegen_spirv/src/builder/mod.rs | 1 + .../rustc_codegen_spirv/src/builder_spirv.rs | 17 + .../src/codegen_cx/declare.rs | 6 + .../rustc_codegen_spirv/src/codegen_cx/mod.rs | 6 + crates/rustc_codegen_spirv/src/symbols.rs | 5 + .../spirv-std/src/byte_addressable_buffer.rs | 56 +++ crates/spirv-std/src/lib.rs | 2 + tests/ui/byte_addressable_buffer/arr.rs | 15 + .../ui/byte_addressable_buffer/big_struct.rs | 24 ++ tests/ui/byte_addressable_buffer/complex.rs | 30 ++ tests/ui/byte_addressable_buffer/f32.rs | 15 + tests/ui/byte_addressable_buffer/u32.rs | 15 + tests/ui/byte_addressable_buffer/vec.rs | 15 + 16 files changed, 596 insertions(+), 2 deletions(-) create mode 100644 crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs create mode 100644 crates/spirv-std/src/byte_addressable_buffer.rs create mode 100644 tests/ui/byte_addressable_buffer/arr.rs create mode 100644 tests/ui/byte_addressable_buffer/big_struct.rs create mode 100644 tests/ui/byte_addressable_buffer/complex.rs create mode 100644 tests/ui/byte_addressable_buffer/f32.rs create mode 100644 tests/ui/byte_addressable_buffer/u32.rs create mode 100644 tests/ui/byte_addressable_buffer/vec.rs diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs index 5067ab429a..2916234a13 100644 --- a/crates/rustc_codegen_spirv/src/attr.rs +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -89,6 +89,8 @@ pub enum SpirvAttribute { // `fn`/closure attributes: UnrollLoops, + BufferLoadIntrinsic, + BufferStoreIntrinsic, } // HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the @@ -122,6 +124,8 @@ pub struct AggregatedSpirvAttributes { // `fn`/closure attributes: pub unroll_loops: Option>, + pub buffer_load_intrinsic: Option>, + pub buffer_store_intrinsic: Option>, } struct MultipleAttrs { @@ -209,6 +213,18 @@ impl AggregatedSpirvAttributes { "#[spirv(attachment_index)]", ), UnrollLoops => try_insert(&mut self.unroll_loops, (), span, "#[spirv(unroll_loops)]"), + BufferLoadIntrinsic => try_insert( + &mut self.buffer_load_intrinsic, + (), + span, + "#[spirv(buffer_load_intrinsic)]", + ), + BufferStoreIntrinsic => try_insert( + &mut self.buffer_store_intrinsic, + (), + span, + "#[spirv(buffer_store_intrinsic)]", + ), } } } @@ -342,6 +358,12 @@ impl CheckSpirvAttrVisitor<'_> { _ => Err(Expected("function or closure")), }, + SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => { + match target { + Target::Fn => Ok(()), + _ => Err(Expected("function")), + } + } }; match valid_target { Err(Expected(expected_target)) => self.tcx.sess.span_err( diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 0201bf86c0..50b7d9efdc 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -1845,8 +1845,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvType::Adt { field_types, .. } => field_types[idx as usize], SpirvType::Array { element, .. } | SpirvType::Vector { element, .. } => element, other => self.fatal(&format!( - "extract_value not implemented on type {:?}", - other + "extract_value not implemented on type {}", + other.debug(agg_val.ty, self) )), }; self.emit() @@ -2196,6 +2196,24 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // needing to materialize `&core::panic::Location` or `format_args!`. self.abort(); self.undef(result_type) + } else if self + .buffer_load_intrinsic_fn_id + .borrow() + .contains(&callee_val) + { + self.codegen_buffer_load_intrinsic(result_type, args) + } else if self + .buffer_store_intrinsic_fn_id + .borrow() + .contains(&callee_val) + { + self.codegen_buffer_store_intrinsic(args); + + let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self); + SpirvValue { + kind: SpirvValueKind::IllegalTypeUsed(void_ty), + ty: void_ty, + } } else { let args = args.iter().map(|arg| arg.def(self)).collect::>(); self.emit() diff --git a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs new file mode 100644 index 0000000000..0450d370fd --- /dev/null +++ b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs @@ -0,0 +1,347 @@ +use super::Builder; +use crate::builder_spirv::{SpirvValue, SpirvValueExt}; +use crate::spirv_type::SpirvType; +use core::array::IntoIter; +use rspirv::spirv::Word; +use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods}; +use rustc_span::DUMMY_SP; +use rustc_target::abi::Align; + +impl<'a, 'tcx> Builder<'a, 'tcx> { + fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue { + let mut err = self.struct_err(&format!( + "Cannot load type {} in an untyped buffer load", + self.debug_type(original_type) + )); + if original_type != invalid_type { + err.note(&format!( + "due to containing type {}", + self.debug_type(invalid_type) + )); + } + err.emit(); + self.undef(invalid_type) + } + + fn load_u32( + &mut self, + array: SpirvValue, + dynamic_index: SpirvValue, + constant_offset: u32, + ) -> SpirvValue { + let actual_index = if constant_offset != 0 { + let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset); + self.add(dynamic_index, const_offset_val) + } else { + dynamic_index + }; + let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); + let u32_ptr = self.type_ptr_to(u32_ty); + let ptr = self + .emit() + .in_bounds_access_chain( + u32_ptr, + None, + array.def(self), + IntoIter::new([actual_index.def(self)]), + ) + .unwrap() + .with_type(u32_ptr); + self.load(u32_ty, ptr, Align::ONE) + } + + #[allow(clippy::too_many_arguments)] + fn load_vec_or_arr( + &mut self, + original_type: Word, + result_type: Word, + array: SpirvValue, + dynamic_word_index: SpirvValue, + constant_word_offset: u32, + element: Word, + count: u32, + ) -> SpirvValue { + let element_size_bytes = match self.lookup_type(element).sizeof(self) { + Some(size) => size, + None => return self.load_err(original_type, result_type), + }; + if element_size_bytes.bytes() % 4 != 0 { + return self.load_err(original_type, result_type); + } + let element_size_words = (element_size_bytes.bytes() / 4) as u32; + let args = (0..count) + .map(|index| { + self.recurse_load_type( + original_type, + element, + array, + dynamic_word_index, + constant_word_offset + element_size_words * index, + ) + .def(self) + }) + .collect::>(); + self.emit() + .composite_construct(result_type, None, args) + .unwrap() + .with_type(result_type) + } + + fn recurse_load_type( + &mut self, + original_type: Word, + result_type: Word, + array: SpirvValue, + dynamic_word_index: SpirvValue, + constant_word_offset: u32, + ) -> SpirvValue { + match self.lookup_type(result_type) { + SpirvType::Integer(32, signed) => { + let val = self.load_u32(array, dynamic_word_index, constant_word_offset); + self.intcast(val, result_type, signed) + } + SpirvType::Float(32) => { + let val = self.load_u32(array, dynamic_word_index, constant_word_offset); + self.bitcast(val, result_type) + } + SpirvType::Vector { element, count } => self.load_vec_or_arr( + original_type, + result_type, + array, + dynamic_word_index, + constant_word_offset, + element, + count, + ), + SpirvType::Array { element, count } => { + let count = match self.builder.lookup_const_u64(count) { + Some(count) => count as u32, + None => return self.load_err(original_type, result_type), + }; + self.load_vec_or_arr( + original_type, + result_type, + array, + dynamic_word_index, + constant_word_offset, + element, + count, + ) + } + SpirvType::Adt { + size: Some(_), + field_types, + field_offsets, + .. + } => { + let args = field_types + .iter() + .zip(field_offsets) + .map(|(&field_type, byte_offset)| { + if byte_offset.bytes() % 4 != 0 { + return None; + } + let word_offset = (byte_offset.bytes() / 4) as u32; + Some( + self.recurse_load_type( + original_type, + field_type, + array, + dynamic_word_index, + constant_word_offset + word_offset, + ) + .def(self), + ) + }) + .collect::>>(); + match args { + None => self.load_err(original_type, result_type), + Some(args) => self + .emit() + .composite_construct(result_type, None, args) + .unwrap() + .with_type(result_type), + } + } + + _ => self.load_err(original_type, result_type), + } + } + + /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller. + pub fn codegen_buffer_load_intrinsic( + &mut self, + result_type: Word, + args: &[SpirvValue], + ) -> SpirvValue { + // Signature: fn load(array: &[u32], index: u32) -> T; + if args.len() != 3 { + self.fatal(&format!( + "buffer_load_intrinsic should have 3 args, it has {}", + args.len() + )); + } + // Note that the &[u32] gets split into two arguments - pointer, length + let array = args[0]; + let byte_index = args[2]; + let two = self.constant_u32(DUMMY_SP, 2); + let word_index = self.lshr(byte_index, two); + self.recurse_load_type(result_type, result_type, array, word_index, 0) + } + + fn store_err(&mut self, original_type: Word, value: SpirvValue) { + let mut err = self.struct_err(&format!( + "Cannot load type {} in an untyped buffer store", + self.debug_type(original_type) + )); + if original_type != value.ty { + err.note(&format!("due to containing type {}", value.ty)); + } + err.emit(); + } + + fn store_u32( + &mut self, + array: SpirvValue, + dynamic_index: SpirvValue, + constant_offset: u32, + value: SpirvValue, + ) { + let actual_index = if constant_offset != 0 { + let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset); + self.add(dynamic_index, const_offset_val) + } else { + dynamic_index + }; + let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); + let u32_ptr = self.type_ptr_to(u32_ty); + let ptr = self + .emit() + .in_bounds_access_chain( + u32_ptr, + None, + array.def(self), + IntoIter::new([actual_index.def(self)]), + ) + .unwrap() + .with_type(u32_ptr); + self.store(value, ptr, Align::ONE); + } + + #[allow(clippy::too_many_arguments)] + fn store_vec_or_arr( + &mut self, + original_type: Word, + value: SpirvValue, + array: SpirvValue, + dynamic_word_index: SpirvValue, + constant_word_offset: u32, + element: Word, + count: u32, + ) { + let element_size_bytes = match self.lookup_type(element).sizeof(self) { + Some(size) => size, + None => return self.store_err(original_type, value), + }; + if element_size_bytes.bytes() % 4 != 0 { + return self.store_err(original_type, value); + } + let element_size_words = (element_size_bytes.bytes() / 4) as u32; + for index in 0..count { + let element = self.extract_value(value, index as u64); + self.recurse_store_type( + original_type, + element, + array, + dynamic_word_index, + constant_word_offset + element_size_words * index, + ); + } + } + + fn recurse_store_type( + &mut self, + original_type: Word, + value: SpirvValue, + array: SpirvValue, + dynamic_word_index: SpirvValue, + constant_word_offset: u32, + ) { + match self.lookup_type(value.ty) { + SpirvType::Integer(32, signed) => { + let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); + let value_u32 = self.intcast(value, u32_ty, signed); + self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32); + } + SpirvType::Float(32) => { + let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); + let value_u32 = self.bitcast(value, u32_ty); + self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32); + } + SpirvType::Vector { element, count } => self.store_vec_or_arr( + original_type, + value, + array, + dynamic_word_index, + constant_word_offset, + element, + count, + ), + SpirvType::Array { element, count } => { + let count = match self.builder.lookup_const_u64(count) { + Some(count) => count as u32, + None => return self.store_err(original_type, value), + }; + self.store_vec_or_arr( + original_type, + value, + array, + dynamic_word_index, + constant_word_offset, + element, + count, + ); + } + SpirvType::Adt { + size: Some(_), + field_offsets, + .. + } => { + for (index, byte_offset) in field_offsets.iter().enumerate() { + if byte_offset.bytes() % 4 != 0 { + return self.store_err(original_type, value); + } + let word_offset = (byte_offset.bytes() / 4) as u32; + let field = self.extract_value(value, index as u64); + self.recurse_store_type( + original_type, + field, + array, + dynamic_word_index, + constant_word_offset + word_offset, + ); + } + } + + _ => self.store_err(original_type, value), + } + } + + /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller. + pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue]) { + // Signature: fn store(array: &[u32], index: u32, value: T); + if args.len() != 4 { + self.fatal(&format!( + "buffer_load_intrinsic should have 4 args, it has {}", + args.len() + )); + } + // Note that the &[u32] gets split into two arguments - pointer, length + let array = args[0]; + let byte_index = args[2]; + let two = self.constant_u32(DUMMY_SP, 2); + let word_index = self.lshr(byte_index, two); + let value = args[3]; + self.recurse_store_type(value.ty, value, array, word_index, 0); + } +} diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index 8364857342..e23078c313 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -1,4 +1,5 @@ mod builder_methods; +mod byte_addressable_buffer; mod ext_inst; mod intrinsics; pub mod libm_intrinsics; diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 17207b4eed..b5b63f0d35 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -25,6 +25,13 @@ pub enum SpirvValueKind { /// of such constants, instead of where they're generated (and cached). IllegalConst(Word), + /// This can only happen in one specific case - which is as a result of + /// `codegen_buffer_store_intrinsic`, that function is supposed to return + /// OpTypeVoid, however because it gets inline by the compiler it can't. + /// Instead we return this, and trigger an error if we ever end up using the + /// result of this function call (which we can't). + IllegalTypeUsed(Word), + // FIXME(eddyb) this shouldn't be needed, but `rustc_codegen_ssa` still relies // on converting `Function`s to `Value`s even for direct calls, the `Builder` // should just have direct and indirect `call` variants (or a `Callee` enum). @@ -132,6 +139,16 @@ impl SpirvValue { id } + SpirvValueKind::IllegalTypeUsed(id) => { + cx.tcx + .sess + .struct_span_err(span, "Can't use type as a value") + .note(&format!("Type: *{}", cx.debug_type(id))) + .emit(); + + id + } + SpirvValueKind::FnAddr { .. } => { if cx.is_system_crate() { cx.builder diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 53151bffdf..644e4d11b1 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -121,6 +121,12 @@ impl<'tcx> CodegenCx<'tcx> { if attrs.unroll_loops.is_some() { self.unroll_loops_decorations.borrow_mut().insert(fn_id); } + if attrs.buffer_load_intrinsic.is_some() { + self.buffer_load_intrinsic_fn_id.borrow_mut().insert(fn_id); + } + if attrs.buffer_store_intrinsic.is_some() { + self.buffer_store_intrinsic_fn_id.borrow_mut().insert(fn_id); + } let instance_def_id = instance.def_id(); diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index f8cb911313..53a10c48f9 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -66,6 +66,10 @@ pub struct CodegenCx<'tcx> { /// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`. pub panic_fn_id: Cell>, + /// Intrinsic for loading a from a &[u32] + pub buffer_load_intrinsic_fn_id: RefCell>, + /// Intrinsic for storing a into a &[u32] + pub buffer_store_intrinsic_fn_id: RefCell>, /// Builtin bounds-checking panics (from MIR `Assert`s) call `#[lang = "panic_bounds_check"]`. pub panic_bounds_check_fn_id: Cell>, @@ -123,6 +127,8 @@ impl<'tcx> CodegenCx<'tcx> { instruction_table: InstructionTable::new(), libm_intrinsics: Default::default(), panic_fn_id: Default::default(), + buffer_load_intrinsic_fn_id: Default::default(), + buffer_store_intrinsic_fn_id: Default::default(), panic_bounds_check_fn_id: Default::default(), i8_i16_atomics_allowed: false, codegen_args, diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index 59781098d9..407190e571 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -335,6 +335,11 @@ impl Symbols { SpirvAttribute::IntrinsicType(IntrinsicType::RuntimeArray), ), ("unroll_loops", SpirvAttribute::UnrollLoops), + ("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic), + ( + "buffer_store_intrinsic", + SpirvAttribute::BufferStoreIntrinsic, + ), ] .iter() .cloned(); diff --git a/crates/spirv-std/src/byte_addressable_buffer.rs b/crates/spirv-std/src/byte_addressable_buffer.rs new file mode 100644 index 0000000000..81688541ba --- /dev/null +++ b/crates/spirv-std/src/byte_addressable_buffer.rs @@ -0,0 +1,56 @@ +use core::mem; + +#[spirv(buffer_load_intrinsic)] +#[spirv_std_macros::gpu_only] +#[allow(improper_ctypes_definitions)] +unsafe extern "unadjusted" fn buffer_load_intrinsic(_buffer: &[u32], _offset: u32) -> T { + unimplemented!() +} // actually implemented in the compiler + +#[spirv(buffer_store_intrinsic)] +#[spirv_std_macros::gpu_only] +#[allow(improper_ctypes_definitions)] +unsafe extern "unadjusted" fn buffer_store_intrinsic( + _buffer: &mut [u32], + _offset: u32, + _value: T, +) { + unimplemented!() +} // actually implemented in the compiler + +#[repr(transparent)] +pub struct ByteAddressableBuffer<'a> { + pub data: &'a mut [u32], +} + +/// `ByteAddressableBuffer` is an untyped blob of data, allowing loads and stores of arbitrary +/// basic data types at arbitrary indicies. However, all data must be aligned to size 4, each +/// element within the data (e.g. struct fields) must have a size and alignment of a multiple of 4, +/// and the `byte_index` passed to load and store must be a multiple of 4 (`byte_index` will be +/// rounded down to the nearest multiple of 4). So, it's not technically a *byte* addressable +/// buffer, but rather a *word* buffer, but this naming and behavior was inhereted from HLSL (where +/// it's UB to pass in an index not a multiple of 4). +impl<'a> ByteAddressableBuffer<'a> { + #[inline] + pub fn new(data: &'a mut [u32]) -> Self { + Self { data } + } + + /// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise, + /// it will get silently rounded down to the nearest multiple of 4. + pub fn load(self, byte_index: u32) -> T { + if byte_index + mem::size_of::() as u32 > self.data.len() as u32 { + panic!("Index out of range") + } + unsafe { buffer_load_intrinsic(self.data, byte_index) } + } + + /// Stores an arbitrary type int the buffer. `byte_index` must be a multiple of 4, otherwise, + /// it will get silently rounded down to the nearest multiple of 4. + pub fn store(self, byte_index: u32, value: T) { + if byte_index + mem::size_of::() as u32 > self.data.len() as u32 { + panic!("Index out of range") + } + unsafe { buffer_store_intrinsic(self.data, byte_index, value) } + } +} diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 8c4c2080c6..16a010330b 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -96,6 +96,7 @@ pub extern crate spirv_std_macros as macros; pub mod arch; +pub mod byte_addressable_buffer; pub mod float; pub mod image; pub mod integer; @@ -109,6 +110,7 @@ pub mod vector; pub use self::sampler::Sampler; pub use crate::macros::Image; +pub use byte_addressable_buffer::ByteAddressableBuffer; pub use num_traits; pub use runtime_array::*; diff --git a/tests/ui/byte_addressable_buffer/arr.rs b/tests/ui/byte_addressable_buffer/arr.rs new file mode 100644 index 0000000000..a6f35314e9 --- /dev/null +++ b/tests/ui/byte_addressable_buffer/arr.rs @@ -0,0 +1,15 @@ +// build-pass + +use spirv_std::{glam::Vec4, ByteAddressableBuffer}; + +#[spirv(fragment)] +pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut [i32; 4]) { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); +} + +#[spirv(fragment)] +pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: [i32; 4]) { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); +} diff --git a/tests/ui/byte_addressable_buffer/big_struct.rs b/tests/ui/byte_addressable_buffer/big_struct.rs new file mode 100644 index 0000000000..bfcbe79719 --- /dev/null +++ b/tests/ui/byte_addressable_buffer/big_struct.rs @@ -0,0 +1,24 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; + +pub struct BigStruct { + a: u32, + b: u32, + c: u32, + d: u32, + e: u32, + f: u32, +} + +#[spirv(fragment)] +pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut BigStruct) { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); +} + +#[spirv(fragment)] +pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: BigStruct) { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); +} diff --git a/tests/ui/byte_addressable_buffer/complex.rs b/tests/ui/byte_addressable_buffer/complex.rs new file mode 100644 index 0000000000..c38e677f78 --- /dev/null +++ b/tests/ui/byte_addressable_buffer/complex.rs @@ -0,0 +1,30 @@ +// build-pass + +use spirv_std::{glam::Vec2, ByteAddressableBuffer}; + +pub struct Complex { + x: u32, + y: f32, + n: Nesty, + v: Vec2, + a: [f32; 7], + m: [Nesty; 2], +} + +pub struct Nesty { + x: f32, + y: f32, + z: f32, +} + +#[spirv(fragment)] +pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut Nesty) { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); +} + +#[spirv(fragment)] +pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: Nesty) { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); +} diff --git a/tests/ui/byte_addressable_buffer/f32.rs b/tests/ui/byte_addressable_buffer/f32.rs new file mode 100644 index 0000000000..aa93e8994f --- /dev/null +++ b/tests/ui/byte_addressable_buffer/f32.rs @@ -0,0 +1,15 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; + +#[spirv(fragment)] +pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut f32) { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); +} + +#[spirv(fragment)] +pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: f32) { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); +} diff --git a/tests/ui/byte_addressable_buffer/u32.rs b/tests/ui/byte_addressable_buffer/u32.rs new file mode 100644 index 0000000000..a0a53592b5 --- /dev/null +++ b/tests/ui/byte_addressable_buffer/u32.rs @@ -0,0 +1,15 @@ +// build-pass + +use spirv_std::ByteAddressableBuffer; + +#[spirv(fragment)] +pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut u32) { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); +} + +#[spirv(fragment)] +pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: u32) { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); +} diff --git a/tests/ui/byte_addressable_buffer/vec.rs b/tests/ui/byte_addressable_buffer/vec.rs new file mode 100644 index 0000000000..3f09dc207c --- /dev/null +++ b/tests/ui/byte_addressable_buffer/vec.rs @@ -0,0 +1,15 @@ +// build-pass + +use spirv_std::{glam::Vec4, ByteAddressableBuffer}; + +#[spirv(fragment)] +pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut Vec4) { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); +} + +#[spirv(fragment)] +pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: Vec4) { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); +} From 6b5e9f0e80093aec9b3564b4fed0664e005601c8 Mon Sep 17 00:00:00 2001 From: khyperia <953151+khyperia@users.noreply.github.com> Date: Wed, 25 Aug 2021 10:39:43 +0200 Subject: [PATCH 2/3] Make ByteAddressableBuffer unsafe --- .../spirv-std/src/byte_addressable_buffer.rs | 18 +++++++++++---- tests/ui/byte_addressable_buffer/arr.rs | 22 ++++++++++++++----- .../ui/byte_addressable_buffer/big_struct.rs | 22 ++++++++++++++----- tests/ui/byte_addressable_buffer/complex.rs | 22 ++++++++++++++----- tests/ui/byte_addressable_buffer/f32.rs | 19 +++++++++++----- tests/ui/byte_addressable_buffer/u32.rs | 19 +++++++++++----- tests/ui/byte_addressable_buffer/vec.rs | 19 +++++++++++----- 7 files changed, 101 insertions(+), 40 deletions(-) diff --git a/crates/spirv-std/src/byte_addressable_buffer.rs b/crates/spirv-std/src/byte_addressable_buffer.rs index 81688541ba..39275fe1a9 100644 --- a/crates/spirv-std/src/byte_addressable_buffer.rs +++ b/crates/spirv-std/src/byte_addressable_buffer.rs @@ -38,19 +38,29 @@ impl<'a> ByteAddressableBuffer<'a> { /// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise, /// it will get silently rounded down to the nearest multiple of 4. - pub fn load(self, byte_index: u32) -> T { + /// + /// # Safety + /// This function allows writing a type to an untyped buffer, then reading a different type + /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a + /// transmute) + pub unsafe fn load(self, byte_index: u32) -> T { if byte_index + mem::size_of::() as u32 > self.data.len() as u32 { panic!("Index out of range") } - unsafe { buffer_load_intrinsic(self.data, byte_index) } + buffer_load_intrinsic(self.data, byte_index) } /// Stores an arbitrary type int the buffer. `byte_index` must be a multiple of 4, otherwise, /// it will get silently rounded down to the nearest multiple of 4. - pub fn store(self, byte_index: u32, value: T) { + /// + /// # Safety + /// This function allows writing a type to an untyped buffer, then reading a different type + /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a + /// transmute) + pub unsafe fn store(self, byte_index: u32, value: T) { if byte_index + mem::size_of::() as u32 > self.data.len() as u32 { panic!("Index out of range") } - unsafe { buffer_store_intrinsic(self.data, byte_index, value) } + buffer_store_intrinsic(self.data, byte_index, value); } } diff --git a/tests/ui/byte_addressable_buffer/arr.rs b/tests/ui/byte_addressable_buffer/arr.rs index a6f35314e9..d94c7fddd7 100644 --- a/tests/ui/byte_addressable_buffer/arr.rs +++ b/tests/ui/byte_addressable_buffer/arr.rs @@ -3,13 +3,23 @@ use spirv_std::{glam::Vec4, ByteAddressableBuffer}; #[spirv(fragment)] -pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut [i32; 4]) { - let buf = ByteAddressableBuffer::new(buf); - *out = buf.load(5); +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + out: &mut [i32; 4], +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); + } } #[spirv(fragment)] -pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: [i32; 4]) { - let buf = ByteAddressableBuffer::new(buf); - buf.store(5, val); +pub fn store( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + val: [i32; 4], +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); + } } diff --git a/tests/ui/byte_addressable_buffer/big_struct.rs b/tests/ui/byte_addressable_buffer/big_struct.rs index bfcbe79719..f3ab2195d4 100644 --- a/tests/ui/byte_addressable_buffer/big_struct.rs +++ b/tests/ui/byte_addressable_buffer/big_struct.rs @@ -12,13 +12,23 @@ pub struct BigStruct { } #[spirv(fragment)] -pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut BigStruct) { - let buf = ByteAddressableBuffer::new(buf); - *out = buf.load(5); +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + out: &mut BigStruct, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); + } } #[spirv(fragment)] -pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: BigStruct) { - let buf = ByteAddressableBuffer::new(buf); - buf.store(5, val); +pub fn store( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + val: BigStruct, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); + } } diff --git a/tests/ui/byte_addressable_buffer/complex.rs b/tests/ui/byte_addressable_buffer/complex.rs index c38e677f78..2f6037071a 100644 --- a/tests/ui/byte_addressable_buffer/complex.rs +++ b/tests/ui/byte_addressable_buffer/complex.rs @@ -18,13 +18,23 @@ pub struct Nesty { } #[spirv(fragment)] -pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut Nesty) { - let buf = ByteAddressableBuffer::new(buf); - *out = buf.load(5); +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + out: &mut Nesty, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); + } } #[spirv(fragment)] -pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: Nesty) { - let buf = ByteAddressableBuffer::new(buf); - buf.store(5, val); +pub fn store( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + val: Nesty, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); + } } diff --git a/tests/ui/byte_addressable_buffer/f32.rs b/tests/ui/byte_addressable_buffer/f32.rs index aa93e8994f..9fcf205886 100644 --- a/tests/ui/byte_addressable_buffer/f32.rs +++ b/tests/ui/byte_addressable_buffer/f32.rs @@ -3,13 +3,20 @@ use spirv_std::ByteAddressableBuffer; #[spirv(fragment)] -pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut f32) { - let buf = ByteAddressableBuffer::new(buf); - *out = buf.load(5); +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + out: &mut f32, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); + } } #[spirv(fragment)] -pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: f32) { - let buf = ByteAddressableBuffer::new(buf); - buf.store(5, val); +pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: f32) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); + } } diff --git a/tests/ui/byte_addressable_buffer/u32.rs b/tests/ui/byte_addressable_buffer/u32.rs index a0a53592b5..8fccbfebf7 100644 --- a/tests/ui/byte_addressable_buffer/u32.rs +++ b/tests/ui/byte_addressable_buffer/u32.rs @@ -3,13 +3,20 @@ use spirv_std::ByteAddressableBuffer; #[spirv(fragment)] -pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut u32) { - let buf = ByteAddressableBuffer::new(buf); - *out = buf.load(5); +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + out: &mut u32, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); + } } #[spirv(fragment)] -pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: u32) { - let buf = ByteAddressableBuffer::new(buf); - buf.store(5, val); +pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: u32) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); + } } diff --git a/tests/ui/byte_addressable_buffer/vec.rs b/tests/ui/byte_addressable_buffer/vec.rs index 3f09dc207c..41ea2664d0 100644 --- a/tests/ui/byte_addressable_buffer/vec.rs +++ b/tests/ui/byte_addressable_buffer/vec.rs @@ -3,13 +3,20 @@ use spirv_std::{glam::Vec4, ByteAddressableBuffer}; #[spirv(fragment)] -pub fn load(#[spirv(storage_buffer)] buf: &mut [u32], out: &mut Vec4) { - let buf = ByteAddressableBuffer::new(buf); - *out = buf.load(5); +pub fn load( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], + out: &mut Vec4, +) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + *out = buf.load(5); + } } #[spirv(fragment)] -pub fn store(#[spirv(storage_buffer)] buf: &mut [u32], val: Vec4) { - let buf = ByteAddressableBuffer::new(buf); - buf.store(5, val); +pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: Vec4) { + unsafe { + let buf = ByteAddressableBuffer::new(buf); + buf.store(5, val); + } } From c36203a9e87615a00bcf2256a034a26b3ad62143 Mon Sep 17 00:00:00 2001 From: Ashley Hauck <953151+khyperia@users.noreply.github.com> Date: Thu, 26 Aug 2021 14:28:44 +0200 Subject: [PATCH 3/3] Update crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs Co-authored-by: Markus Siglreithmaier --- .../rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs index 0450d370fd..607b5bf310 100644 --- a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs +++ b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs @@ -332,7 +332,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // Signature: fn store(array: &[u32], index: u32, value: T); if args.len() != 4 { self.fatal(&format!( - "buffer_load_intrinsic should have 4 args, it has {}", + "buffer_store_intrinsic should have 4 args, it has {}", args.len() )); }