Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unsafe improvements: core parquet crate. #6024

Merged
merged 3 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion parquet/src/bloom_filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ impl Block {

#[inline]
fn to_ne_bytes(self) -> [u8; 32] {
unsafe { std::mem::transmute(self) }
// SAFETY: [u32; 8] and [u8; 32] have the same size and neither has invalid bit patterns.
unsafe { std::mem::transmute(self.0) }
}

#[inline]
Expand Down
32 changes: 22 additions & 10 deletions parquet/src/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,8 @@ macro_rules! gen_as_bytes {
impl AsBytes for $source_ty {
#[allow(clippy::size_of_in_element_count)]
fn as_bytes(&self) -> &[u8] {
// SAFETY: macro is only used with primitive types that have no padding, so the
// resulting slice always refers to initialized memory.
unsafe {
std::slice::from_raw_parts(
self as *const $source_ty as *const u8,
Expand All @@ -481,6 +483,8 @@ macro_rules! gen_as_bytes {
#[inline]
#[allow(clippy::size_of_in_element_count)]
fn slice_as_bytes(self_: &[Self]) -> &[u8] {
// SAFETY: macro is only used with primitive types that have no padding, so the
// resulting slice always refers to initialized memory.
unsafe {
std::slice::from_raw_parts(
self_.as_ptr() as *const u8,
Expand All @@ -492,10 +496,15 @@ macro_rules! gen_as_bytes {
#[inline]
#[allow(clippy::size_of_in_element_count)]
unsafe fn slice_as_bytes_mut(self_: &mut [Self]) -> &mut [u8] {
std::slice::from_raw_parts_mut(
self_.as_mut_ptr() as *mut u8,
std::mem::size_of_val(self_),
)
// SAFETY: macro is only used with primitive types that have no padding, so the
// resulting slice always refers to initialized memory. Moreover, self has no
// invalid bit patterns, so all writes to the resulting slice will be valid.
unsafe {
std::slice::from_raw_parts_mut(
self_.as_mut_ptr() as *mut u8,
std::mem::size_of_val(self_),
)
}
}
}
};
Expand Down Expand Up @@ -534,12 +543,15 @@ unimplemented_slice_as_bytes!(FixedLenByteArray);

impl AsBytes for bool {
    /// Exposes this `bool` as a one-byte slice borrowed from `self`.
    fn as_bytes(&self) -> &[u8] {
        let byte_ptr = (self as *const bool).cast::<u8>();
        // SAFETY: `byte_ptr` is derived from a live `&bool`, so it is non-null, aligned for
        // `u8` and readable for one byte; a bool is guaranteed to be either 0x00 or 0x01 in
        // memory, so that byte is initialized.
        unsafe { std::slice::from_raw_parts(byte_ptr, 1) }
    }
}

impl AsBytes for Int96 {
    /// Exposes the 12 bytes backing this value's `u32` words as a byte slice.
    fn as_bytes(&self) -> &[u8] {
        let words = self.data() as *const [u32];
        // SAFETY: Int96::data is a &[u32; 3], i.e. 12 initialized bytes that remain
        // borrowed for as long as `self` is.
        unsafe { std::slice::from_raw_parts(words as *const u8, 12) }
    }
}
Expand Down Expand Up @@ -718,6 +730,7 @@ pub(crate) mod private {

#[inline]
fn encode<W: std::io::Write>(values: &[Self], writer: &mut W, _: &mut BitWriter) -> Result<()> {
// SAFETY: Self is one of i32, i64, f32, f64, which have no padding.
let raw = unsafe {
std::slice::from_raw_parts(
values.as_ptr() as *const u8,
Expand Down Expand Up @@ -747,9 +760,10 @@ pub(crate) mod private {
return Err(eof_err!("Not enough bytes to decode"));
}

// SAFETY: Raw types should be as per the standard rust bit-vectors
unsafe {
let raw_buffer = &mut Self::slice_as_bytes_mut(buffer)[..bytes_to_decode];
{
// SAFETY: Self has no invalid bit patterns, so writing to the slice
// obtained with slice_as_bytes_mut is always safe.
let raw_buffer = &mut unsafe { Self::slice_as_bytes_mut(buffer) }[..bytes_to_decode];
raw_buffer.copy_from_slice(data.slice(
decoder.start..decoder.start + bytes_to_decode
).as_ref());
Expand Down Expand Up @@ -810,9 +824,7 @@ pub(crate) mod private {
_: &mut BitWriter,
) -> Result<()> {
for value in values {
let raw = unsafe {
std::slice::from_raw_parts(value.data() as *const [u32] as *const u8, 12)
};
let raw = SliceAsBytes::slice_as_bytes(value.data());
writer.write_all(raw)?;
}
Ok(())
Expand Down
45 changes: 39 additions & 6 deletions parquet/src/util/bit_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,11 @@ fn array_from_slice<const N: usize>(bs: &[u8]) -> Result<[u8; N]> {
}
}

pub trait FromBytes: Sized {
/// # Safety
/// Every bit pattern of the form `0…0xxxx` (all bits zero except possibly the lowest
/// `BIT_CAPACITY` bits, shown as `x`) must be a valid value of the implementing type,
/// unless `BIT_CAPACITY` is 0.
pub unsafe trait FromBytes: Sized {
const BIT_CAPACITY: usize;
type Buffer: AsMut<[u8]> + Default;
fn try_from_le_slice(b: &[u8]) -> Result<Self>;
fn from_le_bytes(bs: Self::Buffer) -> Self;
Expand All @@ -51,7 +55,9 @@ pub trait FromBytes: Sized {
macro_rules! from_le_bytes {
($($ty: ty),*) => {
$(
impl FromBytes for $ty {
// SAFETY: this macro is used for types for which all bit patterns are valid.
unsafe impl FromBytes for $ty {
const BIT_CAPACITY: usize = std::mem::size_of::<$ty>() * 8;
type Buffer = [u8; size_of::<Self>()];
fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Ok(Self::from_le_bytes(array_from_slice(b)?))
Expand All @@ -66,7 +72,9 @@ macro_rules! from_le_bytes {

from_le_bytes! { u8, u16, u32, u64, i8, i16, i32, i64, f32, f64 }

impl FromBytes for bool {
// SAFETY: the 0000000x bit pattern is always valid for `bool`.
unsafe impl FromBytes for bool {
const BIT_CAPACITY: usize = 1;
type Buffer = [u8; 1];

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -77,7 +85,9 @@ impl FromBytes for bool {
}
}

impl FromBytes for Int96 {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for Int96 {
const BIT_CAPACITY: usize = 0;
type Buffer = [u8; 12];

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -95,7 +105,9 @@ impl FromBytes for Int96 {
}
}

impl FromBytes for ByteArray {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for ByteArray {
const BIT_CAPACITY: usize = 0;
type Buffer = Vec<u8>;

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand All @@ -106,7 +118,9 @@ impl FromBytes for ByteArray {
}
}

impl FromBytes for FixedLenByteArray {
// SAFETY: BIT_CAPACITY is 0.
unsafe impl FromBytes for FixedLenByteArray {
const BIT_CAPACITY: usize = 0;
type Buffer = Vec<u8>;

fn try_from_le_slice(b: &[u8]) -> Result<Self> {
Expand Down Expand Up @@ -457,10 +471,17 @@ impl BitReader {
}
}

assert_ne!(T::BIT_CAPACITY, 0);
assert!(num_bits <= T::BIT_CAPACITY);

// Read directly into output buffer
match size_of::<T>() {
1 => {
let ptr = batch.as_mut_ptr() as *mut u8;
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 8 {
let out_slice = (&mut out[i..i + 8]).try_into().unwrap();
Expand All @@ -471,6 +492,10 @@ impl BitReader {
}
2 => {
let ptr = batch.as_mut_ptr() as *mut u16;
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 16 {
let out_slice = (&mut out[i..i + 16]).try_into().unwrap();
Expand All @@ -481,6 +506,10 @@ impl BitReader {
}
4 => {
let ptr = batch.as_mut_ptr() as *mut u32;
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 32 {
let out_slice = (&mut out[i..i + 32]).try_into().unwrap();
Expand All @@ -491,6 +520,10 @@ impl BitReader {
}
8 => {
let ptr = batch.as_mut_ptr() as *mut u64;
// SAFETY: batch is properly aligned and sized. Caller guarantees that all bit patterns
// in which only the lowest T::BIT_CAPACITY bits of T are set are valid,
// unpack{8,16,32,64} only ever set the lowest num_bits bits to non-zero values, and we
// checked that num_bits <= T::BIT_CAPACITY.
let out = unsafe { std::slice::from_raw_parts_mut(ptr, batch.len()) };
while values_to_read - i >= 64 {
let out_slice = (&mut out[i..i + 64]).try_into().unwrap();
Expand Down
Loading