Skip to content

Commit

Permalink
fix first segemt & add some example
Browse files Browse the repository at this point in the history
  • Loading branch information
my-vegetable-has-exploded committed Sep 10, 2023
1 parent 6033ac2 commit 0737698
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 76 deletions.
38 changes: 35 additions & 3 deletions examples/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
//!
//! You can try this example by running:
//!
//! cargo run --example server
//! cargo run --example server <server_ip> <port>
//!
//! And then start client in another terminal by running:
//!
//! cargo run --example client
//! cargo run --example client <server_ip> <port>
use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, Rdma, RdmaBuilder};
use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, MrAccess, RCStream, Rdma, RdmaBuilder};
use std::{
alloc::Layout,
env,
Expand Down Expand Up @@ -118,6 +118,35 @@ async fn request_then_write_cas(rdma: &Rdma) -> io::Result<()> {
Ok(())
}

async fn rcstream_send(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// alloc 8 bytes local memory
let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
// write data into lmr
let _num = lmr.as_mut_slice().write(&[i as u8; 8])?;
// send data in mr to the remote end
stream.send_lmr(lmr).await?;
println!("stream send datagram {} ", i);
}
Ok(())
}

async fn rcstream_recv(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// recieve data from the remote end
let mut lmr_vec = stream.recieve_lmr(8).await?;
println!("stream recieve datagram {}", i);
// check the length of the recieved data
assert!(lmr_vec.len() == 1);
let lmr = lmr_vec.pop().unwrap();
assert!(lmr.length() == 8);
let buff = *(lmr.as_slice());
// check the data
assert_eq!(buff, [i as u8; 8]);
}
Ok(())
}

#[tokio::main]
async fn main() {
println!("client start");
Expand Down Expand Up @@ -153,5 +182,8 @@ async fn main() {
request_then_write_with_imm(&rdma).await.unwrap();
request_then_write_cas(&rdma).await.unwrap();
}
let mut stream: RCStream = rdma.into();
rcstream_send(&mut stream).await.unwrap();
rcstream_recv(&mut stream).await.unwrap();
println!("client done");
}
39 changes: 36 additions & 3 deletions examples/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
//!
//! You can try this example by running:
//!
//! cargo run --example server
//! cargo run --example server <server_ip> <port>
//!
//! And start client in another terminal by running:
//!
//! cargo run --example client
//! cargo run --example client <server_ip> <port>
use async_rdma::{LocalMrReadAccess, Rdma, RdmaBuilder};
use async_rdma::{LocalMrReadAccess, LocalMrWriteAccess, MrAccess, RCStream, Rdma, RdmaBuilder};
use clippy_utilities::Cast;
use std::io::Write;
use std::{alloc::Layout, env, io, process::exit};

/// receive data from client
Expand Down Expand Up @@ -90,6 +91,35 @@ async fn receive_mr_after_being_written_by_cas(rdma: &Rdma) -> io::Result<()> {
Ok(())
}

async fn rcstream_send(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// alloc 8 bytes local memory
let mut lmr = stream.alloc_local_mr(Layout::new::<[u8; 8]>())?;
// write data into lmr
let _num = lmr.as_mut_slice().write(&[i as u8; 8])?;
// send data in mr to the remote end
stream.send_lmr(lmr).await?;
println!("stream send datagram {} ", i);
}
Ok(())
}

async fn rcstream_recv(stream: &mut RCStream) -> io::Result<()> {
for i in 0..10 {
// recieve data from the remote end
let mut lmr_vec = stream.recieve_lmr(8).await?;
println!("stream recieve datagram {}", i);
// check the length of the recieved data
assert!(lmr_vec.len() == 1);
let lmr = lmr_vec.pop().unwrap();
assert!(lmr.length() == 8);
let buff = *(lmr.as_slice());
// check the data
assert_eq!(buff, [i as u8; 8]);
}
Ok(())
}

#[tokio::main]
async fn main() {
println!("server start");
Expand Down Expand Up @@ -129,5 +159,8 @@ async fn main() {
.unwrap();
receive_mr_after_being_written_by_cas(&rdma).await.unwrap();
}
let mut stream: RCStream = rdma.into();
rcstream_recv(&mut stream).await.unwrap();
rcstream_send(&mut stream).await.unwrap();
println!("server done");
}
15 changes: 3 additions & 12 deletions src/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -266,17 +266,6 @@ impl Agent {
Ok(req_submitted)
}

pub(crate) async fn split_to_parts(&self, mut lm: LocalMr) -> Vec<LocalMr> {
let mut lm_len = lm.length();
let mut parts = Vec::new();
while lm_len > 0 {
let end = self.max_msg_len().min(lm_len);
parts.push(lm.split_to(end).unwrap());
lm_len -= end;
}
parts
}

/// Receive content sent from the other side and stored in the `LocalMr`
pub(crate) async fn receive_data(&self) -> io::Result<(LocalMr, Option<u32>)> {
let (lmr, len, imm) = self
Expand Down Expand Up @@ -726,7 +715,7 @@ impl AgentInner {
) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
let data_len: usize = data.iter().map(|l| l.length()).sum();
assert!(data_len <= self.max_sr_data_len);
let (tx, mut rx) = channel(2);
let (tx, rx) = channel(2);
let req_id = self
.response_waits
.lock()
Expand Down Expand Up @@ -1037,6 +1026,7 @@ enum Message {
Response(Response),
}

/// Queue pair operation submitted in wq, waitting for wc & response
#[derive(Debug)]
pub(crate) struct RequestSubmitted<Op: QueuePairOp> {
/// the operation of the request
Expand All @@ -1046,6 +1036,7 @@ pub(crate) struct RequestSubmitted<Op: QueuePairOp> {
}

impl<Op: QueuePairOp> RequestSubmitted<Op> {
/// Create a new `RequestSubmitted`
fn new(
inflight: QueuePairOpsInflight<Op>,
rx: Receiver<Result<ResponseKind, io::Error>>,
Expand Down
83 changes: 45 additions & 38 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1675,6 +1675,7 @@ impl Rdma {
/// submit send of the `lm`
///
/// Used with `receive`.
#[allow(unused)]
#[inline]
async fn submit_send(
&self,
Expand All @@ -1696,6 +1697,7 @@ impl Rdma {
lm: Vec<LocalMr>,
imm: u32,
) -> io::Result<RequestSubmitted<QPSendOwn<LocalMr>>> {
debug!("submit send seq_id {:?}", imm);
self.agent
.as_ref()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "Agent is not ready"))?
Expand Down Expand Up @@ -4088,7 +4090,7 @@ impl Rdma {
}
}

/// The wrapper of a RDMA RC connection, with convient methods to write and read.
/// The wrapper of a RDMA RC connection, with convient methods to write and read and order guarantee.
/// TODO how to close stream
#[derive(Debug)]
pub struct RCStream {
Expand All @@ -4107,19 +4109,13 @@ pub struct RCStream {
}

impl RCStream {
/// Create a new RCStream with Rdma
/// Create a new `RCStream` with Rdma
#[must_use]
pub fn new(inner: Rdma) -> Self {
let (inflights_tx, mut inflights_rx) = mpsc::channel(1024);
let _ = tokio::spawn(async move {
loop {
match inflights_rx.recv().await {
Some(inflight) => {
let _ = Self::handle_send_wc(inflight).await;
}
None => {
break;
}
}
while let Some(inflight) = inflights_rx.recv().await {
let _ = Self::handle_send_wc(inflight).await;
}
});
RCStream {
Expand All @@ -4132,15 +4128,7 @@ impl RCStream {
}
}

/// Get the next sequence number, if the current sequence number is u32::MAX, return 0
fn next_seq(seq: u32) -> u32 {
if seq == u32::MAX {
0
} else {
seq + 1
}
}

/// handle the send wc, and check `ResponseKind`
async fn handle_send_wc(inflight: RequestSubmitted<QPSendOwn<LocalMr>>) -> io::Result<()> {
let resp = inflight.response().await?;
if let ResponseKind::SendData(send_data_resp) = resp {
Expand All @@ -4165,26 +4153,30 @@ impl RCStream {
Ok(())
}

/// Send a LocalMr to the remote peer.
/// Send a `LocalMr` whose size is less than `max_message_length` to the remote peer.
pub async fn send_lmr_segment(&mut self, lmr_segment: LocalMr) -> io::Result<()> {
let inflight = self
.inner
.submit_send_with_imm(vec![lmr_segment], self.send_seq)
.await?;
self.send_seq = self.send_seq.wrapping_add(1);
match self.inflights_tx.send(inflight).await {
Ok(_) => Ok(()),
Err(_) => Err(io::Error::new(
io::ErrorKind::Other,
"inflight queue is full",
)),
}
}

/// Send a `LocalMr` to the remote peer.
pub async fn send_lmr(&mut self, mut lmr: LocalMr) -> io::Result<()> {
while lmr.length() > self.inner.clone_attr.agent_attr.max_message_length {
// split to multiple lmr segments
let lmr_segment = lmr.split_to(self.inner.clone_attr.agent_attr.max_message_length);

let inflight = self
.inner
.submit_send_with_imm(vec![lmr_segment.unwrap()], self.send_seq)
.await?;
self.send_seq = Self::next_seq(self.send_seq);
match self.inflights_tx.send(inflight).await {
Ok(_) => {}
Err(_) => {
return Err(io::Error::new(
io::ErrorKind::Other,
"inflight queue is full",
))
}
}
self.send_lmr_segment(lmr_segment.unwrap()).await?;
}
self.send_lmr_segment(lmr).await?;
Ok(())
}

Expand All @@ -4194,11 +4186,13 @@ impl RCStream {
match self.recv_buf.remove(&self.recv_seq) {
Some(lmr) => {
self.read_buf = Some(lmr);
self.recv_seq = Self::next_seq(self.recv_seq);
self.recv_seq = self.recv_seq.wrapping_add(1);
break;
}
None => {
// if next lmr is not in recv_buf, wait for next recv and check again
let (lmr, seq) = self.inner.receive_with_imm().await?;
debug!("recieve seq: {:?}", seq);
match self.recv_buf.insert(seq.unwrap(), lmr) {
Some(_) => {
return Err(io::Error::new(
Expand All @@ -4214,7 +4208,7 @@ impl RCStream {
Ok(())
}

/// recieve LocalMrs whose total size equal to the given size
/// recieve `LocalMrs` whose total size equal to the given size
/// TODO how to read eof?
pub async fn recieve_lmr(&mut self, mut fill_size: usize) -> io::Result<Vec<LocalMr>> {
let mut ret_lmr = vec![];
Expand All @@ -4235,6 +4229,19 @@ impl RCStream {
}
Ok(ret_lmr)
}

/// Allocate a local memory region
/// The parameter `layout` can be obtained by `Layout::new::<Data>()`.
pub fn alloc_local_mr(&mut self, layout: Layout) -> io::Result<LocalMr> {
self.inner.alloc_local_mr(layout)
}
}

impl From<Rdma> for RCStream {
/// Create a `RCStream` from a Rdma
fn from(rdma: Rdma) -> Self {
Self::new(rdma)
}
}

/// Rdma Listener is the wrapper of a `TcpListener`, which is used to
Expand Down
37 changes: 19 additions & 18 deletions src/memory_region/local.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,28 +527,29 @@ impl LocalMr {
/// # Examples
///
/// ```
/// #[tokio::test]
/// async fn test_lmr_split() -> io::Result<()> {
/// let rdma = RdmaBuilder::default()
/// .set_port_num(1)
/// .set_gid_index(1)
/// .build()?;
/// let layout = Layout::new::<[u8; 4096]>();
/// let mut lmr = rdma.alloc_local_mr(layout)?;
/// let start_addr = lmr.addr();
/// let lmr_half = lmr.split_to(2048);
/// assert!(lmr_half.is_some());
/// let lmr_half = lmr_half.unwrap();
/// assert_eq!(lmr_half.length(), 2048);
/// assert_eq!(lmr_half.addr(), start_addr);
/// let lmr_overbound = lmr.split_to(2049);
/// assert!(lmr_overbound.is_none());
/// Ok(())
/// }
/// #[tokio::test]
/// async fn test_lmr_split() -> io::Result<()> {
/// let rdma = RdmaBuilder::default()
/// .set_port_num(1)
/// .set_gid_index(1)
/// .build()?;
/// let layout = Layout::new::<[u8; 4096]>();
/// let mut lmr = rdma.alloc_local_mr(layout)?;
/// let start_addr = lmr.addr();
/// let lmr_half = lmr.split_to(2048);
/// assert!(lmr_half.is_some());
/// let lmr_half = lmr_half.unwrap();
/// assert_eq!(lmr_half.length(), 2048);
/// assert_eq!(lmr_half.addr(), start_addr);
/// let lmr_overbound = lmr.split_to(2049);
/// assert!(lmr_overbound.is_none());
/// Ok(())
/// }
/// ```
/// # Panics
///
/// Panics if `at > len`.
#[inline]
pub fn split_to(&mut self, at: usize) -> Option<Self> {
// SAFETY: `self` is checked to be valid and in bounds above.
if at > self.length() {
Expand Down
5 changes: 3 additions & 2 deletions src/queue_pair.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1312,6 +1312,7 @@ where
}

#[derive(Debug)]
/// Queue pair operation submitted in wq, waitting for wc
pub(crate) struct QueuePairOpsInflight<Op: QueuePairOp> {
/// the operation
op: Op,
Expand Down Expand Up @@ -1425,7 +1426,7 @@ impl<Op: QueuePairOp + Unpin> Future for QueuePairOps<Op> {
}
}

/// Queue pair operation wrapper, return after libv_post_send
/// Queue pair operation wrapper, return after `libv_post_send`
#[derive(Debug)]
pub(crate) struct QueuePairOpsSubmit<Op: QueuePairOp + Unpin> {
/// the internal queue pair
Expand All @@ -1437,7 +1438,7 @@ pub(crate) struct QueuePairOpsSubmit<Op: QueuePairOp + Unpin> {
}

impl<Op: QueuePairOp + Unpin> QueuePairOpsSubmit<Op> {
/// Create a new queue QueuePairOpsSubmit wrapper
/// Create a new queue `QueuePairOpsSubmit` wrapper
fn new(qp: Arc<QueuePair>, op: Op, inners: LmrInners) -> Self {
Self {
qp,
Expand Down

0 comments on commit 0737698

Please sign in to comment.