Skip to content

Commit

Permalink
Make FromBytes an unsafe trait.
Browse files Browse the repository at this point in the history
  • Loading branch information
veluca93 committed Jul 9, 2024
1 parent 52db15e commit aab8250
Showing 1 changed file with 39 additions and 22 deletions.
61 changes: 39 additions & 22 deletions parquet/src/util/bit_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> {
}
}

pub trait FromBytes: Sized {
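// # Safety
// Every bit pattern in which only the lowest `BIT_CAPACITY` bits may be set
// (i.e. 0...0x...x, with `BIT_CAPACITY` `x`s) must be a valid value of `Self`.
// Implementations with `BIT_CAPACITY` equal to 0 are exempt from this requirement.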
pub unsafe trait FromBytes: Sized {
const BIT_CAPACITY: usize;
type Buffer: AsMut<[u8]> + Default;
fn try_from_le_slice(b: &[u8]) -> Result<Self>;
fn from_le_bytes(bs: Self::Buffer) -> Self;
Expand All @@ -51,7 +55,9 @@ pub trait FromBytes: Sized {
macro_rules! from_le_bytes {
($($ty: ty),*) => {
$(
impl FromBytes for $ty {
// SAFETY: this macro is used for types for which all bit patterns are valid.
unsafe impl FromBytes for $ty {
const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
type Buffer = [u8; size_of::<Self>()];
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Ok(Self::from_le_bytes(array_from_slice(b)?))
Expand All @@ -66,7 +72,9 @@ macro_rules! from_le_bytes {

from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 }

impl FromBytes for bool {
// SAFETY: the 0000000x bit pattern is always valid for `bool`.
unsafe impl FromBytes for bool {
const BIT_CAPACITY: usize = 1;
type Buffer = [u8; 1];

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -77,7 +85,9 @@ impl FromBytes for bool {
}
}

impl FromBytes for Int96 {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for Int96 {
const BIT_CAPACITY: usize = 0;
type Buffer = [u8; 12];

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -95,7 +105,9 @@ impl FromBytes for Int96 {
}
}

impl FromBytes for ByteArray {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for ByteArray {
const BIT_CAPACITY: usize = 0;
type Buffer = Vec<u8>;

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -106,7 +118,9 @@ impl FromBytes for ByteArray {
}
}

impl FromBytes for FixedLenByteArray {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for FixedLenByteArray {
const BIT_CAPACITY: usize = 0;
type Buffer = Vec<u8>;

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand Down Expand Up @@ -435,10 +449,6 @@ impl BitReader {
/// This function panics if
/// - `num_bits` is larger than the bit-capacity of `T`
///
// FIXME: soundness issue - this method can be used to write arbitrary bytes to any
// T. A possible fix would be to make `FromBytes` an unsafe trait (or to use a
// separate marker trait) which requires all bit patterns of T to be valid (note that this is
// not the case for `T` = `bool`).
pub fn get_batch<T: FromBytes>(&mut self, batch: &mut [T], num_bits: usize) -> usize {
assert!(num_bits <= size_of::<T>() * 8);

Expand All @@ -461,13 +471,17 @@ impl BitReader {
}
}

assert_ne!(T::BIT_CAPACITY, 0);
assert!(num_bits <= T::BIT_CAPACITY);

// Read directly into output buffer
match size_of::<T>() {
1 => {
let ptr = batch.as_mut_ptr() as *mut u8;
// SAFETY: batch is properly aligned and sized. Caller guarantees that T
// can be safely seen as a slice of bytes through FromBytes bound
// (FIXME: not actually true right now)
// SAFETY: batch is properly aligned and sized. The `FromBytes` implementation guarantees
// that every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is
// valid, unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values,
// and we checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 8 {
let out_slice = (&mut out[i..i + 8]).try_into().unwrap();
Expand All @@ -478,9 +492,10 @@ impl BitReader {
}
2 => {
let ptr = batch.as_mut_ptr() as *mut u16;
// SAFETY: batch is properly aligned and sized. Caller guarantees that T
// can be safely seen as a slice of bytes through FromBytes bound
// (FIXME: not actually true right now)
// SAFETY: batch is properly aligned and sized. The `FromBytes` implementation guarantees
// that every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is
// valid, unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values,
// and we checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 16 {
let out_slice = (&mut out[i..i + 16]).try_into().unwrap();
Expand All @@ -491,9 +506,10 @@ impl BitReader {
}
4 => {
let ptr = batch.as_mut_ptr() as *mut u32;
// SAFETY: batch is properly aligned and sized. Caller guarantees that T
// can be safely seen as a slice of bytes through FromBytes bound
// (FIXME: not actually true right now)
// SAFETY: batch is properly aligned and sized. The `FromBytes` implementation guarantees
// that every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is
// valid, unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values,
// and we checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 32 {
let out_slice = (&mut out[i..i + 32]).try_into().unwrap();
Expand All @@ -504,9 +520,10 @@ impl BitReader {
}
8 => {
let ptr = batch.as_mut_ptr() as *mut u64;
// SAFETY: batch is properly aligned and sized. Caller guarantees that T
// can be safely seen as a slice of bytes through FromBytes bound
// (FIXME: not actually true right now)
// SAFETY: batch is properly aligned and sized. The `FromBytes` implementation guarantees
// that every bit pattern in which only the lowest T::BIT_CAPACITY bits of T are set is
// valid, unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values,
// and we checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 64 {
let out_slice = (&mut out[i..i + 64]).try_into().unwrap();
Expand Down

0 comments on commit aab8250

Please sign in to comment.