Skip to content

Commit

Permalink
Fix recvmmsg(2) implementation
Browse files Browse the repository at this point in the history
There were two problems discovered with the `recvmmsg(2)` implementation
that this changeset attempts to fix:

1. As mentioned in /issues/1325, `recvmmsg(2)` can return fewer
   messages than requested, and
2. Passing the return value of `recvmmsg(2)` as the number of bytes in
   the messages received is incorrect.

This changeset incorporates the proposed fix from /issues/1325,
as well as passing the correct value (`mmsghdr.msg_len`) for the number
of bytes in a given message.
  • Loading branch information
codeslinger committed Nov 28, 2020
1 parent 1794a47 commit aef3068
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/sys/socket/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1217,17 +1217,18 @@ pub fn recvmmsg<'a, I>(

let ret = unsafe { libc::recvmmsg(fd, output.as_mut_ptr(), output.len() as _, flags.bits() as _, timeout) };

let r = Errno::result(ret)?;
let _ = Errno::result(ret)?;

Ok(output
.into_iter()
.take(ret as usize)
.zip(addresses.iter().map(|addr| unsafe{addr.assume_init()}))
.zip(results.into_iter())
.map(|((mmsghdr, address), (msg_controllen, cmsg_buffer))| {
unsafe {
read_mhdr(
mmsghdr.msg_hdr,
r as isize,
mmsghdr.msg_len as isize,
msg_controllen,
address,
cmsg_buffer
Expand Down
70 changes: 70 additions & 0 deletions test/sys/test_socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,76 @@ mod recvfrom {

send_thread.join().unwrap();
}

#[cfg(any(
target_os = "linux",
target_os = "android",
target_os = "freebsd",
target_os = "netbsd",
))]
#[test]
pub fn udp_recvmmsg_dontwait_short_read() {
use nix::sys::uio::IoVec;
use nix::sys::socket::{MsgFlags, recvmmsg};

const NUM_MESSAGES_SENT: usize = 2;
const DATA: [u8; 4] = [1,2,3,4];

let std_sa = SocketAddr::from_str("127.0.0.1:6799").unwrap();
let inet_addr = InetAddr::from_std(&std_sa);
let sock_addr = SockAddr::new_inet(inet_addr);

let rsock = socket(AddressFamily::Inet,
SockType::Datagram,
SockFlag::empty(),
None
).unwrap();
bind(rsock, &sock_addr).unwrap();
let ssock = socket(
AddressFamily::Inet,
SockType::Datagram,
SockFlag::empty(),
None,
).expect("send socket failed");

let send_thread = thread::spawn(move || {
for _ in 0..NUM_MESSAGES_SENT {
sendto(ssock, &DATA[..], &sock_addr, MsgFlags::empty()).unwrap();
}
});
// Ensure we've sent all the messages before continuing so `recvmmsg`
// will return right away
send_thread.join().unwrap();

let mut msgs = std::collections::LinkedList::new();

// Buffers to receive >`NUM_MESSAGES_SENT` messages to ensure `recvmmsg`
// will return when there are fewer than requested messages in the
// kernel buffers when using `MSG_DONTWAIT`.
let mut receive_buffers = [[0u8; 32]; NUM_MESSAGES_SENT + 2];
let iovs: Vec<_> = receive_buffers.iter_mut().map(|buf| {
[IoVec::from_mut_slice(&mut buf[..])]
}).collect();

for iov in &iovs {
msgs.push_back(RecvMmsgData {
iov: iov,
cmsg_buffer: None,
})
};

let res = recvmmsg(rsock, &mut msgs, MsgFlags::MSG_DONTWAIT, None).expect("recvmmsg");
assert_eq!(res.len(), NUM_MESSAGES_SENT);

for RecvMsg { address, bytes, .. } in res.into_iter() {
assert_eq!(AddressFamily::Inet, address.unwrap().family());
assert_eq!(DATA.len(), bytes);
}

for buf in &receive_buffers[..NUM_MESSAGES_SENT] {
assert_eq!(&buf[..DATA.len()], DATA);
}
}
}

// Test error handling of our recvmsg wrapper
Expand Down

0 comments on commit aef3068

Please sign in to comment.