Skip to content

Commit

Permalink
Add support for WebSockets over HTTP/2
Browse files Browse the repository at this point in the history
  • Loading branch information
SabrinaJewson committed Aug 28, 2024
1 parent 1ac617a commit 98030ba
Show file tree
Hide file tree
Showing 13 changed files with 376 additions and 83 deletions.
1 change: 1 addition & 0 deletions axum/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ features = [
[dev-dependencies]
anyhow = "1.0"
axum-macros = { path = "../axum-macros", version = "0.4.1", features = ["__private"] }
hyper = { version = "1.1.0", features = ["client"] }
quickcheck = "1.0"
quickcheck_macros = "1.0"
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] }
Expand Down
204 changes: 153 additions & 51 deletions axum/src/extract/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
//! ```
//! use axum::{
//! extract::ws::{WebSocketUpgrade, WebSocket},
//! routing::get,
//! routing::any,
//! response::{IntoResponse, Response},
//! Router,
//! };
//!
//! let app = Router::new().route("/ws", get(handler));
//! let app = Router::new().route("/ws", any(handler));
//!
//! async fn handler(ws: WebSocketUpgrade) -> Response {
//! ws.on_upgrade(handle_socket)
Expand Down Expand Up @@ -40,7 +40,7 @@
//! use axum::{
//! extract::{ws::{WebSocketUpgrade, WebSocket}, State},
//! response::Response,
//! routing::get,
//! routing::any,
//! Router,
//! };
//!
Expand All @@ -58,7 +58,7 @@
//! }
//!
//! let app = Router::new()
//! .route("/ws", get(handler))
//! .route("/ws", any(handler))
//! .with_state(AppState { /* ... */ });
//! # let _: Router = app;
//! ```
Expand Down Expand Up @@ -102,7 +102,7 @@ use futures_util::{
use http::{
header::{self, HeaderMap, HeaderName, HeaderValue},
request::Parts,
Method, StatusCode,
Method, StatusCode, Version,
};
use hyper_util::rt::TokioIo;
use sha1::{Digest, Sha1};
Expand All @@ -122,17 +122,21 @@ use tokio_tungstenite::{

/// Extractor for establishing WebSocket connections.
///
/// Note: This extractor requires the request method to be `GET` so it should
/// always be used with [`get`](crate::routing::get). Requests with other methods will be
/// rejected.
/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`;
/// in later versions, `CONNECT` is used instead. Thus it should either be used
/// with [`any`](crate::routing::any), or placed behind
/// [`on`](crate::routing::on)`(`[`MethodFilter`]`::GET.or(`[`MethodFilter`]`::POST), ...)`.
///
/// See the [module docs](self) for an example.
///
/// [`MethodFilter`]: crate::routing::MethodFilter
#[cfg_attr(docsrs, doc(cfg(feature = "ws")))]
pub struct WebSocketUpgrade<F = DefaultOnFailedUpgrade> {
config: WebSocketConfig,
/// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response.
protocol: Option<HeaderValue>,
sec_websocket_key: HeaderValue,
/// `None` if HTTP/2+ WebSockets are used.
sec_websocket_key: Option<HeaderValue>,
on_upgrade: hyper::upgrade::OnUpgrade,
on_failed_upgrade: F,
sec_websocket_protocol: Option<HeaderValue>,
Expand Down Expand Up @@ -330,25 +334,34 @@ impl<F> WebSocketUpgrade<F> {
callback(socket).await;
});

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(self.sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}
if let Some(sec_websocket_key) = &self.sec_websocket_key {
// If `sec_websocket_key` was `Some`, we are using HTTP/1.1.

#[allow(clippy::declare_interior_mutable_const)]
const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade");
#[allow(clippy::declare_interior_mutable_const)]
const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket");

let mut builder = Response::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(header::CONNECTION, UPGRADE)
.header(header::UPGRADE, WEBSOCKET)
.header(
header::SEC_WEBSOCKET_ACCEPT,
sign(sec_websocket_key.as_bytes()),
);

if let Some(protocol) = self.protocol {
builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol);
}

builder.body(Body::empty()).unwrap()
builder.body(Body::empty()).unwrap()
} else {
// Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond
// with a 2XX with an empty body:
// <https://datatracker.ietf.org/doc/html/rfc9113#name-the-connect-method>.
Response::new(Body::empty())
}
}
}

Expand Down Expand Up @@ -389,28 +402,46 @@ where
type Rejection = WebSocketUpgradeRejection;

async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}
let sec_websocket_key = if parts.version <= Version::HTTP_11 {
if parts.method != Method::GET {
return Err(MethodNotGet.into());
}

if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}
if !header_contains(&parts.headers, header::CONNECTION, "upgrade") {
return Err(InvalidConnectionHeader.into());
}

if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}
if !header_eq(&parts.headers, header::UPGRADE, "websocket") {
return Err(InvalidUpgradeHeader.into());
}

Some(
parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone(),
)
} else {
if parts.method != Method::CONNECT {
return Err(MethodNotConnect.into());
}

if parts
.extensions
.get::<hyper::ext::Protocol>()
.map_or(true, |p| p.as_str() != "websocket")
{
return Err(InvalidProtocolPseudoheader.into());
}

None
};

if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") {
return Err(InvalidWebSocketVersionHeader.into());
}

let sec_websocket_key = parts
.headers
.get(header::SEC_WEBSOCKET_KEY)
.ok_or(WebSocketKeyHeaderMissing)?
.clone();

let on_upgrade = parts
.extensions
.remove::<hyper::upgrade::OnUpgrade>()
Expand Down Expand Up @@ -708,6 +739,13 @@ pub mod rejection {
pub struct MethodNotGet;
}

define_rejection! {
#[status = METHOD_NOT_ALLOWED]
#[body = "Request method must be `CONNECT`"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct MethodNotConnect;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "Connection header did not include 'upgrade'"]
Expand All @@ -722,6 +760,13 @@ pub mod rejection {
pub struct InvalidUpgradeHeader;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "`:protocol` pseudo-header did not include 'websocket'"]
/// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade).
pub struct InvalidProtocolPseudoheader;
}

define_rejection! {
#[status = BAD_REQUEST]
#[body = "`Sec-WebSocket-Version` header did not include '13'"]
Expand Down Expand Up @@ -757,8 +802,10 @@ pub mod rejection {
/// extractor can fail.
pub enum WebSocketUpgradeRejection {
MethodNotGet,
MethodNotConnect,
InvalidConnectionHeader,
InvalidUpgradeHeader,
InvalidProtocolPseudoheader,
InvalidWebSocketVersionHeader,
WebSocketKeyHeaderMissing,
ConnectionNotUpgradable,
Expand Down Expand Up @@ -833,8 +880,16 @@ mod tests {
use std::future::ready;

use super::*;
use crate::{routing::get, test_helpers::spawn_service, Router};
use crate::{
routing::{any, get},
test_helpers::spawn_service,
Router,
};
use http::{Request, Version};
use http_body_util::BodyExt as _;
use hyper_util::rt::TokioExecutor;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio_tungstenite::tungstenite;
use tower::ServiceExt;

Expand Down Expand Up @@ -883,11 +938,56 @@ mod tests {

#[crate::test]
async fn integration_test() {
let app = Router::new().route(
"/echo",
get(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
);
let addr = spawn_service(echo_app());
let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
test_echo_app(socket).await;
}

#[crate::test]
#[cfg(feature = "http2")]
async fn http2() {
let addr = spawn_service(echo_app());
let io = TokioIo::new(TcpStream::connect(addr).await.unwrap());
let (mut send_request, conn) =
hyper::client::conn::http2::Builder::new(TokioExecutor::new())
.handshake(io)
.await
.unwrap();

// Wait a little for the SETTINGS frame to go through…
for _ in 0..10 {
tokio::task::yield_now().await;
}
assert!(conn.is_extended_connect_protocol_enabled());
tokio::spawn(async {
conn.await.unwrap();
});

let req = Request::builder()
.method(Method::CONNECT)
.extension(hyper::ext::Protocol::from_static("websocket"))
.uri("/echo")
.header("sec-websocket-version", "13")
.header("Host", "server.example.com")
.body(Body::empty())
.unwrap();

let response = send_request.send_request(req).await.unwrap();
let status = response.status();
if status != 200 {
let body = response.into_body().collect().await.unwrap().to_bytes();
let body = std::str::from_utf8(&body).unwrap();
panic!("response status was {}: {body}", status);
}
let upgraded = hyper::upgrade::on(response).await.unwrap();
let upgraded = TokioIo::new(upgraded);
let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await;
test_echo_app(socket).await;
}

fn echo_app() -> Router {
async fn handle_socket(mut socket: WebSocket) {
while let Some(Ok(msg)) = socket.recv().await {
match msg {
Expand All @@ -903,11 +1003,13 @@ mod tests {
}
}

let addr = spawn_service(app);
let (mut socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo"))
.await
.unwrap();
Router::new().route(
"/echo",
any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))),
)
}

async fn test_echo_app<S: AsyncRead + AsyncWrite + Unpin>(mut socket: WebSocketStream<S>) {
let input = tungstenite::Message::Text("foobar".to_owned());
socket.send(input.clone()).await.unwrap();
let output = socket.next().await.unwrap().unwrap();
Expand Down
8 changes: 3 additions & 5 deletions axum/src/routing/method_routing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1034,13 +1034,11 @@ where
match $svc {
MethodEndpoint::None => {}
MethodEndpoint::Route(route) => {
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
return route.clone().oneshot_inner($req);
}
MethodEndpoint::BoxedHandler(handler) => {
let route = handler.clone().into_route(state);
return RouteFuture::from_future(route.clone().oneshot_inner($req))
.strip_body($method == Method::HEAD);
let mut route = handler.clone().into_route(state);
return route.oneshot_inner($req);
}
}
}
Expand Down
6 changes: 2 additions & 4 deletions axum/src/routing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -658,12 +658,10 @@ where

fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture<E> {
match self {
Fallback::Default(route) | Fallback::Service(route) => {
RouteFuture::from_future(route.oneshot_inner(req))
}
Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner(req),
Fallback::BoxedHandler(handler) => {
let mut route = handler.clone().into_route(state);
RouteFuture::from_future(route.oneshot_inner(req))
route.oneshot_inner(req)
}
}
}
Expand Down
Loading

0 comments on commit 98030ba

Please sign in to comment.