Minimal async decoder support #46

Draft
wants to merge 10 commits into master
7 changes: 7 additions & 0 deletions Cargo.toml
@@ -13,6 +13,12 @@ edition = "2018"
arrayref = "0.3.5"
arrayvec = "0.7.1"
blake3 = "1.0.0"
tokio = { version = "1.24.2", features = [], default-features=false, optional = true }

[features]
# todo: remove before merge
default = ["tokio_io"]
tokio_io = ["tokio"]

[dev-dependencies]
lazy_static = "1.3.0"
@@ -23,3 +29,4 @@ tempfile = "3.1.0"
rand_chacha = "0.3.1"
rand_xorshift = "0.3.0"
page_size = "0.4.1"
tokio = { version = "1.24.2", features = ["full"] }
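
The new tokio_io feature is how downstream users would opt in to the async decoder. A hypothetical consumer's Cargo.toml, with the crate name and versions assumed for illustration only:

[dependencies]
# placeholder name/version; enable the new feature explicitly
bao = { version = "...", features = ["tokio_io"] }
# the decoder itself needs no tokio features; rt and macros are only
# for #[tokio::main] in the usage sketch further down
tokio = { version = "1", features = ["rt", "macros"] }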
232 changes: 227 additions & 5 deletions src/decode.rs
@@ -205,7 +205,7 @@ impl From<Error> for io::Error {

// Shared between Decoder and SliceDecoder.
#[derive(Clone)]
-struct DecoderShared<T: Read, O: Read> {
+struct DecoderShared<T, O> {
input: T,
outboard: Option<O>,
state: VerifyState,
@@ -214,7 +214,7 @@ struct DecoderShared<T: Read, O: Read> {
buf_end: usize,
}

-impl<T: Read, O: Read> DecoderShared<T, O> {
+impl<T, O> DecoderShared<T, O> {
fn new(input: T, outboard: Option<O>, hash: &Hash) -> Self {
Self {
input,
@@ -240,7 +240,9 @@ impl<T: Read, O: Read> DecoderShared<T, O> {
self.buf_start = 0;
self.buf_end = 0;
}
}

impl<T: Read, O: Read> DecoderShared<T, O> {
// These bytes are always verified before going in the buffer.
fn take_buffered_bytes(&mut self, output: &mut [u8]) -> usize {
let take = cmp::min(self.buf_len(), output.len());
@@ -441,7 +443,7 @@ impl<T: Read + Seek, O: Read + Seek> DecoderShared<T, O> {
}
}

-impl<T: Read, O: Read> fmt::Debug for DecoderShared<T, O> {
+impl<T, O> fmt::Debug for DecoderShared<T, O> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
@@ -454,6 +456,226 @@ impl<T: Read, O: Read> fmt::Debug for DecoderShared<T, O> {
}
}

#[cfg(feature = "tokio_io")]
mod tokio_io {
use super::{DecoderShared, Hash, NextRead};
use std::{
cmp,
convert::TryInto,
io,
pin::Pin,
task::{self, ready},
};
use tokio::io::{AsyncRead, ReadBuf};

    // Tokio-flavored async IO utilities, requiring AsyncRead.
impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> DecoderShared<T, O> {
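        // Overview of the state machine below: the shared buffer alternates
        // between a "reading" phase, where buf[buf_start..buf_end] is filled
        // from the input (or outboard) and then verified, and a "writing"
        // phase, where the already-verified bytes in buf[buf_start..buf_end]
        // are handed out to the caller.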
fn poll_read(
&mut self,
cx: &mut task::Context,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
// Explicitly short-circuit zero-length reads. We're within our rights
// to buffer an internal chunk in this case, or to make progress if
// there's an empty chunk, but this matches the current behavior of
// SliceExtractor for zero-length slices. This might change in the
// future.
if buf.remaining() == 0 {
return task::Poll::Ready(Ok(()));
}

// Otherwise try to verify a new chunk.
loop {
// If there are bytes in the internal buffer, just return those.
if self.buf_len() > 0 {
let n = cmp::min(buf.remaining(), self.buf_len());
buf.put_slice(&self.buf[self.buf_start..self.buf_start + n]);
self.buf_start += n;
// if we are done with writing, go into the reading state
if self.buf_len() == 0 {
self.clear_buf();
}
return task::Poll::Ready(Ok(()));
}

match self.state.read_next() {
NextRead::Done => {
// This is EOF. We know the internal buffer is empty,
// because we checked it before this loop.
return task::Poll::Ready(Ok(()));
}
NextRead::Header => {
// ensure reading state, reading 8 bytes
// we might already be in the reading state,
// so we must not set buf_start to 0
self.buf_end = 8;
// header comes from outboard if we have one, otherwise from input
ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?;
self.state.feed_header(self.buf[0..8].try_into().unwrap());
// we don't want to write the header, so we are done with the buffer contents
self.clear_buf();
}
NextRead::Parent => {
// ensure reading state, reading 64 bytes
// we might already be in the reading state,
// so we must not set buf_start to 0
self.buf_end = 64;
// parent comes from outboard if we have one, otherwise from input
ready!(self.poll_fill_buffer_from_input_or_outboard(cx))?;
self.state
.feed_parent(&self.buf[0..64].try_into().unwrap())?;
// we don't want to write the parent, so we are done with the buffer contents
self.clear_buf();
}
NextRead::Chunk {
size,
finalization,
skip,
index,
} => {
// todo: add direct output optimization

// ensure reading state, reading size bytes
// we might already be in the reading state,
// so we must not set buf_start to 0
self.buf_end = size;
// chunk never comes from outboard
ready!(self.poll_fill_buffer_from_input(cx))?;

// Hash it and push its hash into the VerifyState. This
// returns an error if the hash is bad. Otherwise, the
// chunk is verified.
let read_buf = &self.buf[0..size];
let chunk_hash = blake3::guts::ChunkState::new(index)
.update(read_buf)
.finalize(finalization.is_root());
self.state.feed_chunk(&chunk_hash)?;

// we go into the writing state now, starting from skip
self.buf_start = skip;
// we should have something to write,
// unless the entire chunk was empty
debug_assert!(self.buf_len() > 0 || size == 0);
}
}
}
}

        fn poll_fill_buffer_from_input(
            &mut self,
            cx: &mut task::Context<'_>,
        ) -> task::Poll<Result<(), io::Error>> {
            let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]);
            buf.advance(self.buf_start);
            let src = &mut self.input;
            while buf.remaining() > 0 {
                let filled_before = buf.filled().len();
                ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?;
                // A poll_read that makes no progress means EOF; error out
                // rather than polling forever.
                if buf.filled().len() == filled_before {
                    return task::Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
                }
                self.buf_start = buf.filled().len();
            }
            task::Poll::Ready(Ok(()))
        }

        fn poll_fill_buffer_from_outboard(
            &mut self,
            cx: &mut task::Context<'_>,
        ) -> task::Poll<Result<(), io::Error>> {
            let mut buf = ReadBuf::new(&mut self.buf[..self.buf_end]);
            buf.advance(self.buf_start);
            let src = self.outboard.as_mut().unwrap();
            while buf.remaining() > 0 {
                let filled_before = buf.filled().len();
                ready!(AsyncRead::poll_read(Pin::new(src), cx, &mut buf))?;
                // Same EOF guard as above: no progress means the outboard
                // stream ended early.
                if buf.filled().len() == filled_before {
                    return task::Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into()));
                }
                self.buf_start = buf.filled().len();
            }
            task::Poll::Ready(Ok(()))
        }

fn poll_fill_buffer_from_input_or_outboard(
&mut self,
cx: &mut task::Context<'_>,
) -> task::Poll<Result<(), io::Error>> {
if self.outboard.is_some() {
self.poll_fill_buffer_from_outboard(cx)
} else {
self.poll_fill_buffer_from_input(cx)
}
}
}

#[derive(Clone, Debug)]
pub struct AsyncDecoder<T: AsyncRead + Unpin, O: AsyncRead + Unpin> {
        // Review note from the author: it would also be possible to have the
        // Decoder support both async and sync, but the ergonomics of that are
        // really bad. I tried it, so I went with this tiny wrapper instead.
shared: DecoderShared<T, O>,
}

impl<T: AsyncRead + Unpin> AsyncDecoder<T, T> {
pub fn new(inner: T, hash: &Hash) -> Self {
Self {
shared: DecoderShared::new(inner, None, hash),
}
}
}

impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> AsyncDecoder<T, O> {
pub fn new_outboard(inner: T, outboard: O, hash: &Hash) -> Self {
Self {
shared: DecoderShared::new(inner, Some(outboard), hash),
}
}

/// Return the underlying reader and the outboard reader, if any. If the `Decoder` was created
/// with `Decoder::new`, the outboard reader will be `None`.
pub fn into_inner(self) -> (T, Option<O>) {
(self.shared.input, self.shared.outboard)
}
}

impl<T: AsyncRead + Unpin, O: AsyncRead + Unpin> AsyncRead for AsyncDecoder<T, O> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> task::Poll<io::Result<()>> {
self.shared.poll_read(cx, buf)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{decode::make_test_input, encode};

#[tokio::test]
async fn test_async_decode() {
for &case in crate::test::TEST_CASES {
use tokio::io::AsyncReadExt;
println!("case {}", case);
let input = make_test_input(case);
            let (encoded, hash) = encode::encode(&input);
let mut output = Vec::new();
let mut reader = AsyncDecoder::new(&encoded[..], &hash);
reader.read_to_end(&mut output).await.unwrap();
assert_eq!(input, output);
}
}

#[tokio::test]
async fn test_async_decode_outboard() {
for &case in crate::test::TEST_CASES {
use tokio::io::AsyncReadExt;
println!("case {}", case);
let input = make_test_input(case);
            let (outboard, hash) = encode::outboard(&input);
let mut output = Vec::new();
let mut reader = AsyncDecoder::new_outboard(&input[..], &outboard[..], &hash);
reader.read_to_end(&mut output).await.unwrap();
assert_eq!(input, output);
}
}
}
}

#[cfg(feature = "tokio_io")]
pub use tokio_io::AsyncDecoder;
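
For reference, a minimal end-to-end sketch of the new API. This is not part of the diff: it assumes the crate is consumed as bao with the tokio_io feature enabled, and it reuses the existing sync encode::encode exactly as the tests above do.

use bao::decode::AsyncDecoder;
use tokio::io::AsyncReadExt;

#[tokio::main]
async fn main() -> std::io::Result<()> {
    // Encode with the existing synchronous encoder...
    let input = b"some input bytes".to_vec();
    let (encoded, hash) = bao::encode::encode(&input);
    // ...then decode and verify asynchronously. `&[u8]` implements AsyncRead.
    let mut decoder = AsyncDecoder::new(&encoded[..], &hash);
    let mut output = Vec::new();
    decoder.read_to_end(&mut output).await?;
    assert_eq!(output, input);
    Ok(())
}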

/// An incremental decoder, which reads and verifies the output of
/// [`Encoder`](../encode/struct.Encoder.html).
///
@@ -861,7 +1083,7 @@ mod test {
let mut output = Vec::new();
let mut decoder = Decoder::new(&*zero_encoded, &zero_hash);
decoder.read_to_end(&mut output).unwrap();
-        assert_eq!(&output, &[]);
+        assert_eq!(output.len(), 0);

// Decoding the empty tree with any other hash should fail.
let mut output = Vec::new();
@@ -936,7 +1158,7 @@ mod test {
let mut decoder = Decoder::new(Cursor::new(&encoded), &hash);
decoder.seek(SeekFrom::Start(case as u64)).unwrap();
decoder.read_to_end(&mut output).unwrap();
-        assert_eq!(&output, &[]);
+        assert_eq!(output.len(), 0);

// Seeking to EOF should fail if the root hash is wrong.
let mut bad_hash_bytes = *hash.as_bytes();
2 changes: 1 addition & 1 deletion src/encode.rs
@@ -1377,7 +1377,7 @@ mod test {
let mut output = Vec::new();
let mut encoder = Encoder::new(io::Cursor::new(&mut output));
encoder.write_all(input).unwrap();
-        encoder.write(&[]).unwrap();
+        encoder.write_all(&[]).unwrap();
let hash = encoder.finalize().unwrap();
assert_eq!((output, hash), encode(input));
assert_eq!(hash, blake3::hash(input));
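
Context for this one-line fix: Write::write is allowed to consume fewer bytes than it was given, and its return value must be checked, while write_all retries until the whole buffer is written or an error occurs. A small standard-library illustration, independent of this diff:

use std::io::Write;

fn main() -> std::io::Result<()> {
    let mut sink: Vec<u8> = Vec::new();
    // `write` reports how many bytes were accepted; that count must be checked.
    let n = sink.write(b"hello")?;
    assert!(n <= 5);
    // `write_all` loops until every byte is written or an error occurs.
    sink.write_all(b"hello")?;
    Ok(())
}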