From f4f57d2d01584b75fc422d514a0375985df31ee1 Mon Sep 17 00:00:00 2001 From: Alex Rudy Date: Wed, 3 Apr 2024 03:24:39 +0000 Subject: [PATCH] Use socket2 to support SO_NODELAY and SO_KEEPALIVE on incoming connections --- tonic/Cargo.toml | 4 +++- tonic/src/transport/server/incoming.rs | 33 ++++++++++++++++++++++++-- tonic/src/transport/server/mod.rs | 21 ++++++++-------- 3 files changed, 44 insertions(+), 14 deletions(-) diff --git a/tonic/Cargo.toml b/tonic/Cargo.toml index d19f0a30b..112eb0b9c 100644 --- a/tonic/Cargo.toml +++ b/tonic/Cargo.toml @@ -48,6 +48,7 @@ transport = [ "dep:hyper-util", "tokio/net", "tokio/time", + "dep:socket2", "dep:tower", "dep:hyper-timeout", ] @@ -80,10 +81,12 @@ prost = { version = "0.12", default-features = false, features = [ async-trait = { version = "0.1.13", optional = true } # transport +axum = { version = "0.7", default_features = false, optional = true } h2 = { version = "0.4", optional = true } hyper = { version = "1.0", features = ["full"], optional = true } hyper-util = { version = "0.1", features = ["full"], optional = true } hyper-timeout = { version = "0.5", optional = true } +socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] } tokio-stream = { version = "0.1", features = ["net"] } tower = { version = "0.4.7", default-features = false, features = [ "balance", @@ -95,7 +98,6 @@ tower = { version = "0.4.7", default-features = false, features = [ "timeout", "util", ], optional = true } -axum = { version = "0.7", default_features = false, optional = true } # rustls async-stream = { version = "0.3", optional = true } diff --git a/tonic/src/transport/server/incoming.rs b/tonic/src/transport/server/incoming.rs index 21ff4ffbb..eb9f55595 100644 --- a/tonic/src/transport/server/incoming.rs +++ b/tonic/src/transport/server/incoming.rs @@ -3,7 +3,7 @@ use crate::transport::service::ServerIo; use std::{ net::{SocketAddr, TcpListener as StdTcpListener}, pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, time::Duration, }; use tokio::{ @@ -12,6 +12,7 @@ use tokio::{ }; use tokio_stream::wrappers::TcpListenerStream; use tokio_stream::{Stream, StreamExt}; +use tracing::warn; #[cfg(not(feature = "tls"))] pub(crate) fn tcp_incoming( @@ -195,7 +196,35 @@ impl Stream for TcpIncoming { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - Pin::new(&mut self.inner).poll_next(cx) + match ready!(Pin::new(&mut self.inner).poll_next(cx)) { + Some(Ok(stream)) => { + set_accept_socketoptions(&stream, self.tcp_nodelay, self.tcp_keepalive_timeout); + Some(Ok(stream)).into() + } + other => Poll::Ready(other), + } + } +} + +// Consistent with hyper-0.14, this function does not return an error. +fn set_accept_socketoptions( + stream: &TcpStream, + tcp_nodelay: bool, + tcp_keepalive_timeout: Option, +) { + if tcp_nodelay { + if let Err(e) = stream.set_nodelay(true) { + warn!("error trying to set TCP nodelay: {}", e); + } + } + + if let Some(timeout) = tcp_keepalive_timeout { + let sock_ref = socket2::SockRef::from(&stream); + let sock_keepalive = socket2::TcpKeepalive::new().with_time(timeout); + + if let Err(e) = sock_ref.set_tcp_keepalive(&sock_keepalive) { + warn!("error trying to set TCP keepalive: {}", e); + } } } diff --git a/tonic/src/transport/server/mod.rs b/tonic/src/transport/server/mod.rs index 05df57926..05652fe46 100644 --- a/tonic/src/transport/server/mod.rs +++ b/tonic/src/transport/server/mod.rs @@ -523,8 +523,8 @@ impl Server { let timeout = self.timeout; let max_frame_size = self.max_frame_size; - // FIXME: this requires additonal implementation here. - let http2_only = !self.accept_http1; + // TODO: Reqiures support from hyper-util + let _http2_only = !self.accept_http1; let http2_keepalive_interval = self.http2_keepalive_interval; let http2_keepalive_timeout = self @@ -532,7 +532,8 @@ impl Server { .unwrap_or_else(|| Duration::new(DEFAULT_HTTP2_KEEPALIVE_TIMEOUT_SECS, 0)); let http2_adaptive_window = self.http2_adaptive_window; - let http2_max_pending_accept_reset_streams = self.http2_max_pending_accept_reset_streams; + // TODO: Requires a new release of hyper and hyper-util + let _http2_max_pending_accept_reset_streams = self.http2_max_pending_accept_reset_streams; let make_service = self.service_builder.service(svc); @@ -547,6 +548,9 @@ impl Server { let mut builder = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()); + //TODO: Set http2-only when available in hyper_util + //builder.http2_only(http2_only); + builder .http2() .initial_connection_window_size(init_connection_window_size) @@ -555,8 +559,8 @@ impl Server { .keep_alive_interval(http2_keepalive_interval) .keep_alive_timeout(http2_keepalive_timeout) .adaptive_window(http2_adaptive_window.unwrap_or_default()) - // FIXME: wait for this to be added to hyper-util - // .max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) + // TODO: wait for this to be added to hyper-util + //.max_pending_accept_reset_streams(http2_max_pending_accept_reset_streams) .max_frame_size(max_frame_size); let (signal_tx, signal_rx) = tokio::sync::watch::channel(()); @@ -1068,18 +1072,13 @@ where } } +// A future which only yields `Poll::Ready` once, and thereafter yields `Poll::Pending`. #[pin_project] struct Fuse { #[pin] inner: Option, } -impl Fuse { - fn is_terminated(self: &Pin<&mut Self>) -> bool { - self.inner.is_none() - } -} - impl Future for Fuse where F: Future,