Skip to content

Commit

Permalink
breaking: update to tokio-tungstenite 0.26 and update examples
Browse files Browse the repository at this point in the history
  • Loading branch information
adryzz committed Dec 18, 2024
1 parent 67fa445 commit 67c2ffc
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 40 deletions.
4 changes: 2 additions & 2 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ serde_path_to_error = { version = "0.1.8", optional = true }
serde_urlencoded = { version = "0.7", optional = true }
sha1 = { version = "0.10", optional = true }
tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true }
tokio-tungstenite = { version = "0.25.0", optional = true }
tokio-tungstenite = { version = "0.26.0", optional = true }
tracing = { version = "0.1", default-features = false, optional = true }

[dependencies.tower-http]
Expand Down Expand Up @@ -127,7 +127,7 @@ serde_json = { version = "1.0", features = ["raw_value"] }
time = { version = "0.3", features = ["serde-human-readable"] }
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
tokio-stream = "0.1"
tokio-tungstenite = "0.25.0"
tokio-tungstenite = "0.26.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["json"] }
uuid = { version = "1.0", features = ["serde", "v4"] }
Expand Down
174 changes: 146 additions & 28 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,11 +114,7 @@ use std::{
use tokio_tungstenite::{
tungstenite::{
self as ts,
protocol::{
self,
frame::{Payload, Utf8Payload},
WebSocketConfig,
},
protocol::{self, WebSocketConfig},
},
WebSocketStream,
};
Expand Down Expand Up @@ -557,16 +553,133 @@ impl Sink<Message> for WebSocket {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
/// UTF-8 wrapper for [Bytes].
///
/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8.
pub struct Utf8Bytes(ts::Utf8Bytes);

impl Utf8Bytes {
/// Creates from a static str.
#[inline]
pub const fn from_static(str: &'static str) -> Self {
Self(ts::Utf8Bytes::from_static(str))
}

/// Returns as a string slice.
#[inline]
pub fn as_str(&self) -> &str {
// SAFETY: is valid uft8
self.0.as_str()
}

fn into_tungstenite(self) -> ts::Utf8Bytes {
self.0
}
}

impl std::ops::Deref for Utf8Bytes {
type Target = str;

/// ```
/// /// Example fn that takes a str slice
/// fn a(s: &str) {}
///
/// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
///
/// // auto-deref as arg
/// a(&data);
///
/// // deref to str methods
/// assert_eq!(data.len(), 6);
/// ```
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}

impl std::fmt::Display for Utf8Bytes {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

impl TryFrom<Bytes> for Utf8Bytes {
type Error = std::str::Utf8Error;

#[inline]
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
std::str::from_utf8(&bytes)?;
Ok(Self(bytes.try_into()?))
}
}

impl TryFrom<Vec<u8>> for Utf8Bytes {
type Error = std::str::Utf8Error;

#[inline]
fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(v.try_into()?))
}
}

impl From<String> for Utf8Bytes {
#[inline]
fn from(s: String) -> Self {
Self(s.into())
}
}

impl From<&str> for Utf8Bytes {
#[inline]
fn from(s: &str) -> Self {
Self(s.into())
}
}

impl From<&String> for Utf8Bytes {
#[inline]
fn from(s: &String) -> Self {
Self(s.into())
}
}

impl From<Utf8Bytes> for Bytes {
#[inline]
fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
bytes.into()
}
}

impl<T> PartialEq<T> for Utf8Bytes
where
for<'a> &'a str: PartialEq<T>,
{
/// ```
/// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
/// assert_eq!(payload, "foo123");
/// assert_eq!(payload, "foo123".to_string());
/// assert_eq!(payload, &"foo123".to_string());
/// assert_eq!(payload, std::borrow::Cow::from("foo123"));
/// ```
#[inline]
fn eq(&self, other: &T) -> bool {
self.as_str() == *other
}
}

/// Status code used to indicate why an endpoint is closing the WebSocket connection.
pub type CloseCode = u16;

/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
pub struct CloseFrame {
/// The reason as a code.
pub code: CloseCode,
/// The reason as text string.
pub reason: Cow<'t, str>,
pub reason: Utf8Bytes,
}

/// A WebSocket message.
Expand Down Expand Up @@ -595,24 +708,24 @@ pub struct CloseFrame<'t> {
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(Utf8Payload),
Text(Utf8Bytes),
/// A binary WebSocket message
Binary(Payload),
Binary(Bytes),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you do not have to worry
/// about dealing with them yourself.
Ping(Payload),
Ping(Bytes),
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Pong messages will be automatically sent to the client if a ping message is received, so
/// you do not have to worry about constructing them yourself unless you want to implement a
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong(Payload),
Pong(Bytes),
/// A close message with the optional close frame.
///
/// You may "uncleanly" close a WebSocket connection at any time
Expand All @@ -632,33 +745,33 @@ pub enum Message {
/// Since no further messages will be received,
/// you may either do nothing
/// or explicitly drop the connection.
Close(Option<CloseFrame<'static>>),
Close(Option<CloseFrame>),
}

impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
reason: close.reason.into_tungstenite(),
})),
Self::Close(None) => ts::Message::Close(None),
}
}

fn from_tungstenite(message: ts::Message) -> Option<Self> {
match message {
ts::Message::Text(text) => Some(Self::Text(text)),
ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
reason: Utf8Bytes(close.reason),
}))),
ts::Message::Close(None) => Some(Self::Close(None)),
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
Expand All @@ -668,12 +781,12 @@ impl Message {
}

/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
pub fn into_data(self) -> Bytes {
match self {
Self::Text(string) => string.to_string().into_bytes(),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data.as_slice().to_vec(),
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
Self::Text(string) => string.into(),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Bytes::new(),
Self::Close(Some(frame)) => frame.reason.into(),
}
}

Expand All @@ -682,12 +795,12 @@ impl Message {
match self {
Self::Text(string) => Ok(string.to_string()),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
Ok(String::from_utf8(data.as_slice().to_vec())
Ok(String::from_utf8(data.to_vec())
.map_err(|err| err.utf8_error())
.map_err(Error::new)?)
}
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
Self::Close(Some(frame)) => Ok(frame.reason.to_string()),
}
}

Expand All @@ -697,7 +810,7 @@ impl Message {
match *self {
Self::Text(ref string) => Ok(string.as_str()),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data.as_slice()).map_err(Error::new)?)
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(""),
Self::Close(Some(ref frame)) => Ok(&frame.reason),
Expand All @@ -719,7 +832,7 @@ impl<'s> From<&'s str> for Message {

impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Self {
Message::Binary(data.into())
Message::Binary(Bytes::copy_from_slice(data))
}
}

Expand All @@ -731,7 +844,7 @@ impl From<Vec<u8>> for Message {

impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
msg.into_data().to_vec()
}
}

Expand Down Expand Up @@ -1038,10 +1151,15 @@ mod tests {
assert_eq!(input, output);

socket
.send(tungstenite::Message::Ping("ping".as_bytes().into()))
.send(tungstenite::Message::Ping(
Bytes::from_static(b"ping").into(),
))
.await
.unwrap();
let output = socket.next().await.unwrap().unwrap();
assert_eq!(output, tungstenite::Message::Pong("ping".as_bytes().into()));
assert_eq!(
output,
tungstenite::Message::Pong(Bytes::from_static(b"ping").into())
);
}
}
2 changes: 1 addition & 1 deletion examples/websockets/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ futures = "0.3"
futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] }
headers = "0.4"
tokio = { version = "1.0", features = ["full"] }
tokio-tungstenite = "0.25.0"
tokio-tungstenite = "0.26.0"
tower-http = { version = "0.6.1", features = ["fs", "trace"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
Expand Down
9 changes: 4 additions & 5 deletions examples/websockets/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
use futures_util::stream::FuturesUnordered;
use futures_util::{SinkExt, StreamExt};
use std::borrow::Cow;
use std::ops::ControlFlow;
use std::time::Instant;
use tokio_tungstenite::tungstenite::protocol::frame::Payload;
use tokio_tungstenite::tungstenite::Utf8Bytes;

// we will use tungstenite for websocket client impl (same library as what axum is using)
use tokio_tungstenite::{
Expand Down Expand Up @@ -66,8 +65,8 @@ async fn spawn_client(who: usize) {

//we can ping the server for start
sender
.send(Message::Ping(Payload::Shared(
"Hello, Server!".as_bytes().into(),
.send(Message::Ping(axum::body::Bytes::from_static(
b"Hello, Server!",
)))
.await
.expect("Can not send!");
Expand All @@ -93,7 +92,7 @@ async fn spawn_client(who: usize) {
if let Err(e) = sender
.send(Message::Close(Some(CloseFrame {
code: CloseCode::Normal,
reason: Cow::from("Goodbye"),
reason: Utf8Bytes::from_static("Goodbye"),
})))
.await
{
Expand Down
8 changes: 4 additions & 4 deletions examples/websockets/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
//! ```
use axum::{
extract::ws::{Message, WebSocket, WebSocketUpgrade},
body::Bytes,
extract::ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
response::IntoResponse,
routing::any,
Router,
};
use axum_extra::TypedHeader;

use std::borrow::Cow;
use std::ops::ControlFlow;
use std::{net::SocketAddr, path::PathBuf};
use tower_http::{
Expand Down Expand Up @@ -102,7 +102,7 @@ async fn ws_handler(
async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
// send a ping (unsupported by some browsers) just to kick things off and get a response
if socket
.send(Message::Ping(vec![1, 2, 3].into()))
.send(Message::Ping(Bytes::from_static(&[1, 2, 3])))
.await
.is_ok()
{
Expand Down Expand Up @@ -169,7 +169,7 @@ async fn handle_socket(mut socket: WebSocket, who: SocketAddr) {
if let Err(e) = sender
.send(Message::Close(Some(CloseFrame {
code: axum::extract::ws::close_code::NORMAL,
reason: Cow::from("Goodbye"),
reason: Utf8Bytes::from_static("Goodbye"),
})))
.await
{
Expand Down

0 comments on commit 67c2ffc

Please sign in to comment.