Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

breaking: update tokio-tungstenite to 0.26 and update examples #3078

Merged
merged 7 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ serde_path_to_error = { version = "0.1.8", optional = true }
serde_urlencoded = { version = "0.7", optional = true }
sha1 = { version = "0.10", optional = true }
tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true }
tokio-tungstenite = { version = "0.24.0", optional = true }
tokio-tungstenite = { version = "0.26.0", optional = true }
tracing = { version = "0.1", default-features = false, optional = true }

[dependencies.tower-http]
Expand Down Expand Up @@ -127,7 +127,7 @@ serde_json = { version = "1.0", features = ["raw_value"] }
time = { version = "0.3", features = ["serde-human-readable"] }
tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] }
tokio-stream = "0.1"
tokio-tungstenite = "0.24.0"
tokio-tungstenite = "0.26.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["json"] }
uuid = { version = "1.0", features = ["serde", "v4"] }
Expand Down
191 changes: 161 additions & 30 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,16 +553,131 @@ impl Sink<Message> for WebSocket {
}
}

/// UTF-8 wrapper for [Bytes].
///
/// An [Utf8Bytes] is always guaranteed to contain valid UTF-8.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct Utf8Bytes(ts::Utf8Bytes);

impl Utf8Bytes {
/// Creates from a static str.
#[inline]
pub const fn from_static(str: &'static str) -> Self {
Self(ts::Utf8Bytes::from_static(str))
}

/// Returns as a string slice.
#[inline]
pub fn as_str(&self) -> &str {
self.0.as_str()
}

fn into_tungstenite(self) -> ts::Utf8Bytes {
self.0
}
}

impl std::ops::Deref for Utf8Bytes {
type Target = str;

/// ```
/// /// Example fn that takes a str slice
/// fn a(s: &str) {}
///
/// let data = axum::extract::ws::Utf8Bytes::from_static("foo123");
///
/// // auto-deref as arg
/// a(&data);
///
/// // deref to str methods
/// assert_eq!(data.len(), 6);
/// ```
#[inline]
fn deref(&self) -> &Self::Target {
self.as_str()
}
}

impl std::fmt::Display for Utf8Bytes {
#[inline]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.as_str())
}
}

impl TryFrom<Bytes> for Utf8Bytes {
type Error = std::str::Utf8Error;

#[inline]
fn try_from(bytes: Bytes) -> Result<Self, Self::Error> {
Ok(Self(bytes.try_into()?))
}
}

impl TryFrom<Vec<u8>> for Utf8Bytes {
type Error = std::str::Utf8Error;

#[inline]
fn try_from(v: Vec<u8>) -> Result<Self, Self::Error> {
Ok(Self(v.try_into()?))
}
}

impl From<String> for Utf8Bytes {
#[inline]
fn from(s: String) -> Self {
Self(s.into())
}
}

impl From<&str> for Utf8Bytes {
#[inline]
fn from(s: &str) -> Self {
Self(s.into())
}
}

impl From<&String> for Utf8Bytes {
#[inline]
fn from(s: &String) -> Self {
Self(s.into())
}
}

impl From<Utf8Bytes> for Bytes {
#[inline]
fn from(Utf8Bytes(bytes): Utf8Bytes) -> Self {
bytes.into()
}
}

impl<T> PartialEq<T> for Utf8Bytes
where
for<'a> &'a str: PartialEq<T>,
Comment on lines +654 to +656
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self: revisit.

{
/// ```
/// let payload = axum::extract::ws::Utf8Bytes::from_static("foo123");
/// assert_eq!(payload, "foo123");
/// assert_eq!(payload, "foo123".to_string());
/// assert_eq!(payload, &"foo123".to_string());
/// assert_eq!(payload, std::borrow::Cow::from("foo123"));
/// ```
#[inline]
fn eq(&self, other: &T) -> bool {
self.as_str() == *other
}
}

/// Status code used to indicate why an endpoint is closing the WebSocket connection.
pub type CloseCode = u16;

/// A struct representing the close command.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct CloseFrame<'t> {
pub struct CloseFrame {
/// The reason as a code.
pub code: CloseCode,
/// The reason as text string.
pub reason: Cow<'t, str>,
pub reason: Utf8Bytes,
}

/// A WebSocket message.
Expand Down Expand Up @@ -591,24 +706,24 @@ pub struct CloseFrame<'t> {
#[derive(Debug, Eq, PartialEq, Clone)]
pub enum Message {
/// A text WebSocket message
Text(String),
Text(Utf8Bytes),
/// A binary WebSocket message
Binary(Vec<u8>),
Binary(Bytes),
/// A ping message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Ping messages will be automatically responded to by the server, so you do not have to worry
/// about dealing with them yourself.
Ping(Vec<u8>),
Ping(Bytes),
/// A pong message with the specified payload
///
/// The payload here must have a length less than 125 bytes.
///
/// Pong messages will be automatically sent to the client if a ping message is received, so
/// you do not have to worry about constructing them yourself unless you want to implement a
/// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3).
Pong(Vec<u8>),
Pong(Bytes),
/// A close message with the optional close frame.
///
/// You may "uncleanly" close a WebSocket connection at any time
Expand All @@ -628,33 +743,33 @@ pub enum Message {
/// Since no further messages will be received,
/// you may either do nothing
/// or explicitly drop the connection.
Close(Option<CloseFrame<'static>>),
Close(Option<CloseFrame>),
}

impl Message {
fn into_tungstenite(self) -> ts::Message {
match self {
Self::Text(text) => ts::Message::Text(text),
Self::Text(text) => ts::Message::Text(text.into_tungstenite()),
Self::Binary(binary) => ts::Message::Binary(binary),
Self::Ping(ping) => ts::Message::Ping(ping),
Self::Pong(pong) => ts::Message::Pong(pong),
Self::Close(Some(close)) => ts::Message::Close(Some(ts::protocol::CloseFrame {
code: ts::protocol::frame::coding::CloseCode::from(close.code),
reason: close.reason,
reason: close.reason.into_tungstenite(),
})),
Self::Close(None) => ts::Message::Close(None),
}
}

fn from_tungstenite(message: ts::Message) -> Option<Self> {
match message {
ts::Message::Text(text) => Some(Self::Text(text)),
ts::Message::Text(text) => Some(Self::Text(Utf8Bytes(text))),
ts::Message::Binary(binary) => Some(Self::Binary(binary)),
ts::Message::Ping(ping) => Some(Self::Ping(ping)),
ts::Message::Pong(pong) => Some(Self::Pong(pong)),
ts::Message::Close(Some(close)) => Some(Self::Close(Some(CloseFrame {
code: close.code.into(),
reason: close.reason,
reason: Utf8Bytes(close.reason),
}))),
ts::Message::Close(None) => Some(Self::Close(None)),
// we can ignore `Frame` frames as recommended by the tungstenite maintainers
Expand All @@ -664,44 +779,60 @@ impl Message {
}

/// Consume the WebSocket and return it as binary data.
pub fn into_data(self) -> Vec<u8> {
pub fn into_data(self) -> Bytes {
match self {
Self::Text(string) => string.into_bytes(),
Self::Text(string) => Bytes::from(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => data,
Self::Close(None) => Vec::new(),
Self::Close(Some(frame)) => frame.reason.into_owned().into_bytes(),
Self::Close(None) => Bytes::new(),
Self::Close(Some(frame)) => Bytes::from(frame.reason),
}
}

/// Attempt to consume the WebSocket message and convert it to a String.
pub fn into_text(self) -> Result<String, Error> {
/// Attempt to consume the WebSocket message and convert it to a Utf8Bytes.
pub fn into_text(self) -> Result<Utf8Bytes, Error> {
match self {
Self::Text(string) => Ok(string),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => Ok(String::from_utf8(data)
.map_err(|err| err.utf8_error())
.map_err(Error::new)?),
Self::Close(None) => Ok(String::new()),
Self::Close(Some(frame)) => Ok(frame.reason.into_owned()),
Self::Binary(data) | Self::Ping(data) | Self::Pong(data) => {
Ok(Utf8Bytes::try_from(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(Utf8Bytes::default()),
Self::Close(Some(frame)) => Ok(frame.reason),
}
}

/// Attempt to get a &str from the WebSocket message,
/// this will try to convert binary data to utf8.
pub fn to_text(&self) -> Result<&str, Error> {
match *self {
Self::Text(ref string) => Ok(string),
Self::Text(ref string) => Ok(string.as_str()),
Self::Binary(ref data) | Self::Ping(ref data) | Self::Pong(ref data) => {
Ok(std::str::from_utf8(data).map_err(Error::new)?)
}
Self::Close(None) => Ok(""),
Self::Close(Some(ref frame)) => Ok(&frame.reason),
}
}

/// Create a new text WebSocket message from a stringable.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how to word this, but "stringable" is not a word in my vocabulary :D

pub fn text<S>(string: S) -> Message
where
S: Into<Utf8Bytes>,
{
Message::Text(string.into())
}

/// Create a new binary WebSocket message by converting to `Bytes`.
pub fn binary<B>(bin: B) -> Message
where
B: Into<Bytes>,
{
Message::Binary(bin.into())
}
}

impl From<String> for Message {
fn from(string: String) -> Self {
Message::Text(string)
Message::Text(string.into())
}
}

Expand All @@ -713,19 +844,19 @@ impl<'s> From<&'s str> for Message {

impl<'b> From<&'b [u8]> for Message {
fn from(data: &'b [u8]) -> Self {
Message::Binary(data.into())
Message::Binary(Bytes::copy_from_slice(data))
}
}

impl From<Vec<u8>> for Message {
fn from(data: Vec<u8>) -> Self {
Message::Binary(data)
Message::Binary(data.into())
}
}

impl From<Message> for Vec<u8> {
fn from(msg: Message) -> Self {
msg.into_data()
msg.into_data().to_vec()
}
}

Expand Down Expand Up @@ -1026,19 +1157,19 @@ mod tests {
}

async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
let input = tungstenite::Message::Text("foobar".to_owned());
let input = tungstenite::Message::Text(tungstenite::Utf8Bytes::from_static("foobar"));
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
assert_eq!(input, output);

socket
.send(tungstenite::Message::Ping("ping".to_owned().into_bytes()))
.send(tungstenite::Message::Ping(Bytes::from_static(b"ping")))
.await
.unwrap();
let output = socket.next().await.unwrap().unwrap();
assert_eq!(
output,
tungstenite::Message::Pong("ping".to_owned().into_bytes())
tungstenite::Message::Pong(Bytes::from_static(b"ping"))
);
}
}
10 changes: 6 additions & 4 deletions examples/chat/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use axum::{
extract::{
ws::{Message, WebSocket, WebSocketUpgrade},
ws::{Message, Utf8Bytes, WebSocket, WebSocketUpgrade},
State,
},
response::{Html, IntoResponse},
Expand Down Expand Up @@ -79,15 +79,17 @@ async fn websocket(stream: WebSocket, state: Arc<AppState>) {
while let Some(Ok(message)) = receiver.next().await {
if let Message::Text(name) = message {
// If username that is sent by client is not taken, fill username string.
check_username(&state, &mut username, &name);
check_username(&state, &mut username, name.as_str());

// If not empty we want to quit the loop else we want to quit function.
if !username.is_empty() {
break;
} else {
// Only send our client that username is taken.
let _ = sender
.send(Message::Text(String::from("Username already taken.")))
.send(Message::Text(Utf8Bytes::from_static(
"Username already taken.",
)))
.await;

return;
Expand All @@ -109,7 +111,7 @@ async fn websocket(stream: WebSocket, state: Arc<AppState>) {
let mut send_task = tokio::spawn(async move {
while let Ok(msg) = rx.recv().await {
// In any websocket error, break loop.
if sender.send(Message::Text(msg)).await.is_err() {
if sender.send(Message::text(msg)).await.is_err() {
break;
}
}
Expand Down
2 changes: 1 addition & 1 deletion examples/testing-websockets/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ publish = false
axum = { path = "../../axum", features = ["ws"] }
futures = "0.3"
tokio = { version = "1.0", features = ["full"] }
tokio-tungstenite = "0.24"
tokio-tungstenite = "0.26"
Loading
Loading