Skip to content

Commit

Permalink
Auto merge of rust-lang#3809 - RalfJung:fd-refcell, r=oli-obk
Browse files Browse the repository at this point in the history
FD: remove big surrounding RefCell, simplify socketpair

A while ago, I added the big implicit RefCell for all file descriptions since it avoided interior mutability in `eventfd`. However, this requires us to hold the RefCell "lock" around the entire invocation of the `read`/`write` methods on an FD, which is not great. For instance, if an FD wants to update epoll notifications from inside its `read`/`write`, it is crucial that the notification check does not end up accessing the FD itself: the FD is already mutably borrowed for the duration of the call, so borrowing it again would panic. Such cycles, however, occur naturally:
- eventfd wants to update notifications for itself
- socketfd wants to update notifications on its "peer", which will in turn check *its own* peer to see whether that buffer is empty -- and the peer's peer is the original socketfd itself.

This then also lets us simplify socketpair, which currently holds a weak reference to its peer *and* a weak reference to the peer's buffer -- that was previously needed precisely to avoid this issue.
  • Loading branch information
bors committed Aug 16, 2024
2 parents 1a51dd9 + 883e477 commit 83f1b38
Show file tree
Hide file tree
Showing 6 changed files with 294 additions and 399 deletions.
121 changes: 51 additions & 70 deletions src/tools/miri/src/shims/unix/fd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
//! standard file descriptors (stdin/stdout/stderr).
use std::any::Any;
use std::cell::{Ref, RefCell, RefMut};
use std::collections::BTreeMap;
use std::io::{self, ErrorKind, IsTerminal, Read, SeekFrom, Write};
use std::ops::Deref;
use std::rc::Rc;
use std::rc::Weak;

Expand All @@ -27,9 +27,9 @@ pub trait FileDescription: std::fmt::Debug + Any {

/// Reads as much as possible into the given buffer, and returns the number of bytes read.
fn read<'tcx>(
&mut self,
&self,
_self_ref: &FileDescriptionRef,
_communicate_allowed: bool,
_fd_id: FdId,
_bytes: &mut [u8],
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, io::Result<usize>> {
Expand All @@ -38,9 +38,9 @@ pub trait FileDescription: std::fmt::Debug + Any {

/// Writes as much as possible from the given buffer, and returns the number of bytes written.
fn write<'tcx>(
&mut self,
&self,
_self_ref: &FileDescriptionRef,
_communicate_allowed: bool,
_fd_id: FdId,
_bytes: &[u8],
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, io::Result<usize>> {
Expand All @@ -50,7 +50,7 @@ pub trait FileDescription: std::fmt::Debug + Any {
/// Reads as much as possible into the given buffer from a given offset,
/// and returns the number of bytes read.
fn pread<'tcx>(
&mut self,
&self,
_communicate_allowed: bool,
_bytes: &mut [u8],
_offset: u64,
Expand All @@ -62,7 +62,7 @@ pub trait FileDescription: std::fmt::Debug + Any {
/// Writes as much as possible from the given buffer starting at a given offset,
/// and returns the number of bytes written.
fn pwrite<'tcx>(
&mut self,
&self,
_communicate_allowed: bool,
_bytes: &[u8],
_offset: u64,
Expand All @@ -74,7 +74,7 @@ pub trait FileDescription: std::fmt::Debug + Any {
/// Seeks to the given offset (which can be relative to the beginning, end, or current position).
/// Returns the new position from the start of the stream.
fn seek<'tcx>(
&mut self,
&self,
_communicate_allowed: bool,
_offset: SeekFrom,
) -> InterpResult<'tcx, io::Result<u64>> {
Expand Down Expand Up @@ -111,14 +111,9 @@ pub trait FileDescription: std::fmt::Debug + Any {

impl dyn FileDescription {
#[inline(always)]
pub fn downcast_ref<T: Any>(&self) -> Option<&T> {
pub fn downcast<T: Any>(&self) -> Option<&T> {
(self as &dyn Any).downcast_ref()
}

#[inline(always)]
pub fn downcast_mut<T: Any>(&mut self) -> Option<&mut T> {
(self as &mut dyn Any).downcast_mut()
}
}

impl FileDescription for io::Stdin {
Expand All @@ -127,17 +122,17 @@ impl FileDescription for io::Stdin {
}

fn read<'tcx>(
&mut self,
&self,
_self_ref: &FileDescriptionRef,
communicate_allowed: bool,
_fd_id: FdId,
bytes: &mut [u8],
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, io::Result<usize>> {
if !communicate_allowed {
// We want isolation mode to be deterministic, so we have to disallow all reads, even stdin.
helpers::isolation_abort_error("`read` from stdin")?;
}
Ok(Read::read(self, bytes))
Ok(Read::read(&mut { self }, bytes))
}

fn is_tty(&self, communicate_allowed: bool) -> bool {
Expand All @@ -151,14 +146,14 @@ impl FileDescription for io::Stdout {
}

fn write<'tcx>(
&mut self,
&self,
_self_ref: &FileDescriptionRef,
_communicate_allowed: bool,
_fd_id: FdId,
bytes: &[u8],
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, io::Result<usize>> {
// We allow writing to stdout even with isolation enabled.
let result = Write::write(self, bytes);
let result = Write::write(&mut { self }, bytes);
// Stdout is buffered, flush to make sure it appears on the
// screen. This is the write() syscall of the interpreted
// program, we want it to correspond to a write() syscall on
Expand All @@ -180,9 +175,9 @@ impl FileDescription for io::Stderr {
}

fn write<'tcx>(
&mut self,
&self,
_self_ref: &FileDescriptionRef,
_communicate_allowed: bool,
_fd_id: FdId,
bytes: &[u8],
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, io::Result<usize>> {
Expand All @@ -206,9 +201,9 @@ impl FileDescription for NullOutput {
}

fn write<'tcx>(
&mut self,
&self,
_self_ref: &FileDescriptionRef,
_communicate_allowed: bool,
_fd_id: FdId,
bytes: &[u8],
_ecx: &mut MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, io::Result<usize>> {
Expand All @@ -221,26 +216,23 @@ impl FileDescription for NullOutput {
#[derive(Clone, Debug)]
pub struct FileDescWithId<T: FileDescription + ?Sized> {
id: FdId,
file_description: RefCell<Box<T>>,
file_description: Box<T>,
}

#[derive(Clone, Debug)]
pub struct FileDescriptionRef(Rc<FileDescWithId<dyn FileDescription>>);

impl FileDescriptionRef {
fn new(fd: impl FileDescription, id: FdId) -> Self {
FileDescriptionRef(Rc::new(FileDescWithId {
id,
file_description: RefCell::new(Box::new(fd)),
}))
}
impl Deref for FileDescriptionRef {
type Target = dyn FileDescription;

pub fn borrow(&self) -> Ref<'_, dyn FileDescription> {
Ref::map(self.0.file_description.borrow(), |fd| fd.as_ref())
fn deref(&self) -> &Self::Target {
&*self.0.file_description
}
}

pub fn borrow_mut(&self) -> RefMut<'_, dyn FileDescription> {
RefMut::map(self.0.file_description.borrow_mut(), |fd| fd.as_mut())
impl FileDescriptionRef {
fn new(fd: impl FileDescription, id: FdId) -> Self {
FileDescriptionRef(Rc::new(FileDescWithId { id, file_description: Box::new(fd) }))
}

pub fn close<'tcx>(
Expand All @@ -256,7 +248,7 @@ impl FileDescriptionRef {
// Remove entry from the global epoll_event_interest table.
ecx.machine.epoll_interests.remove(id);

RefCell::into_inner(fd.file_description).close(communicate_allowed, ecx)
fd.file_description.close(communicate_allowed, ecx)
}
None => Ok(Ok(())),
}
Expand All @@ -269,16 +261,6 @@ impl FileDescriptionRef {
pub fn get_id(&self) -> FdId {
self.0.id
}

/// Function used to retrieve the readiness events of a file description and insert
/// an `EpollEventInstance` into the ready list if the file description is ready.
pub(crate) fn check_and_update_readiness<'tcx>(
&self,
ecx: &mut InterpCx<'tcx, MiriMachine<'tcx>>,
) -> InterpResult<'tcx, ()> {
use crate::shims::unix::linux::epoll::EvalContextExt;
ecx.check_and_update_readiness(self.get_id(), || self.borrow_mut().get_epoll_ready_events())
}
}

/// Holds a weak reference to the actual file description.
Expand Down Expand Up @@ -334,11 +316,20 @@ impl FdTable {
fds
}

/// Insert a new file description to the FdTable.
pub fn insert_new(&mut self, fd: impl FileDescription) -> i32 {
pub fn new_ref(&mut self, fd: impl FileDescription) -> FileDescriptionRef {
let file_handle = FileDescriptionRef::new(fd, self.next_file_description_id);
self.next_file_description_id = FdId(self.next_file_description_id.0.strict_add(1));
self.insert_ref_with_min_fd(file_handle, 0)
file_handle
}

/// Insert a new file description to the FdTable.
pub fn insert_new(&mut self, fd: impl FileDescription) -> i32 {
let fd_ref = self.new_ref(fd);
self.insert(fd_ref)
}

pub fn insert(&mut self, fd_ref: FileDescriptionRef) -> i32 {
self.insert_ref_with_min_fd(fd_ref, 0)
}

/// Insert a file description, giving it a file descriptor that is at least `min_fd`.
Expand Down Expand Up @@ -368,17 +359,7 @@ impl FdTable {
new_fd
}

pub fn get(&self, fd: i32) -> Option<Ref<'_, dyn FileDescription>> {
let fd = self.fds.get(&fd)?;
Some(fd.borrow())
}

pub fn get_mut(&self, fd: i32) -> Option<RefMut<'_, dyn FileDescription>> {
let fd = self.fds.get(&fd)?;
Some(fd.borrow_mut())
}

pub fn get_ref(&self, fd: i32) -> Option<FileDescriptionRef> {
pub fn get(&self, fd: i32) -> Option<FileDescriptionRef> {
let fd = self.fds.get(&fd)?;
Some(fd.clone())
}
Expand All @@ -397,7 +378,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn dup(&mut self, old_fd: i32) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_mut();

let Some(dup_fd) = this.machine.fds.get_ref(old_fd) else {
let Some(dup_fd) = this.machine.fds.get(old_fd) else {
return Ok(Scalar::from_i32(this.fd_not_found()?));
};
Ok(Scalar::from_i32(this.machine.fds.insert_ref_with_min_fd(dup_fd, 0)))
Expand All @@ -406,7 +387,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
fn dup2(&mut self, old_fd: i32, new_fd: i32) -> InterpResult<'tcx, Scalar> {
let this = self.eval_context_mut();

let Some(dup_fd) = this.machine.fds.get_ref(old_fd) else {
let Some(dup_fd) = this.machine.fds.get(old_fd) else {
return Ok(Scalar::from_i32(this.fd_not_found()?));
};
if new_fd != old_fd {
Expand Down Expand Up @@ -492,7 +473,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
}
let start = this.read_scalar(&args[2])?.to_i32()?;

match this.machine.fds.get_ref(fd) {
match this.machine.fds.get(fd) {
Some(dup_fd) =>
Ok(Scalar::from_i32(this.machine.fds.insert_ref_with_min_fd(dup_fd, start))),
None => Ok(Scalar::from_i32(this.fd_not_found()?)),
Expand Down Expand Up @@ -565,7 +546,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let communicate = this.machine.communicate();

// We temporarily dup the FD to be able to retain mutable access to `this`.
let Some(fd) = this.machine.fds.get_ref(fd) else {
let Some(fd) = this.machine.fds.get(fd) else {
trace!("read: FD not found");
return Ok(Scalar::from_target_isize(this.fd_not_found()?, this));
};
Expand All @@ -576,14 +557,14 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
// `usize::MAX` because it is bounded by the host's `isize`.
let mut bytes = vec![0; usize::try_from(count).unwrap()];
let result = match offset {
None => fd.borrow_mut().read(communicate, fd.get_id(), &mut bytes, this),
None => fd.read(&fd, communicate, &mut bytes, this),
Some(offset) => {
let Ok(offset) = u64::try_from(offset) else {
let einval = this.eval_libc("EINVAL");
this.set_last_error(einval)?;
return Ok(Scalar::from_target_isize(-1, this));
};
fd.borrow_mut().pread(communicate, &mut bytes, offset, this)
fd.pread(communicate, &mut bytes, offset, this)
}
};

Expand Down Expand Up @@ -629,19 +610,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {

let bytes = this.read_bytes_ptr_strip_provenance(buf, Size::from_bytes(count))?.to_owned();
// We temporarily dup the FD to be able to retain mutable access to `this`.
let Some(fd) = this.machine.fds.get_ref(fd) else {
let Some(fd) = this.machine.fds.get(fd) else {
return Ok(Scalar::from_target_isize(this.fd_not_found()?, this));
};

let result = match offset {
None => fd.borrow_mut().write(communicate, fd.get_id(), &bytes, this),
None => fd.write(&fd, communicate, &bytes, this),
Some(offset) => {
let Ok(offset) = u64::try_from(offset) else {
let einval = this.eval_libc("EINVAL");
this.set_last_error(einval)?;
return Ok(Scalar::from_target_isize(-1, this));
};
fd.borrow_mut().pwrite(communicate, &bytes, offset, this)
fd.pwrite(communicate, &bytes, offset, this)
}
};

Expand Down
Loading

0 comments on commit 83f1b38

Please sign in to comment.