Skip to content

Commit

Permalink
feat!: Yield remote address upon accepting a connection, and include …
Browse files Browse the repository at this point in the history
…it in errors.

BREAKING CHANGE: The enum variant `Error::ListenerError` is now struct-like instead of tuple-like, and is `non_exhaustive` like the enum itself.

BREAKING CHANGE: `Error` now has three type parameters, not two.

BREAKING CHANGE: `TlsListener::accept` and `<TlsListener as Stream>::next` yields a tuple of (connection, remote address), not just the connection.

BREAKING CHANGE: `AsyncAccept` now has an associated type `Address`, which `poll_accept` must now return along with the accepted connection.
  • Loading branch information
ahcodedthat authored and tmccombs committed Oct 17, 2023
1 parent 6bd78b6 commit e920fbb
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 50 deletions.
6 changes: 3 additions & 3 deletions examples/echo-threads.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -32,8 +32,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(SpawningHandshakes(tls_acceptor()), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
eprintln!("Error: {:?}", e);
Expand Down
10 changes: 7 additions & 3 deletions examples/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ mod tls_config;
use tls_config::tls_acceptor;

#[inline]
async fn handle_stream(stream: TlsStream<TcpStream>) {
async fn handle_stream(stream: TlsStream<TcpStream>, _remote_addr: SocketAddr) {
let (mut reader, mut writer) = split(stream);
match copy(&mut reader, &mut writer).await {
Ok(cnt) => eprintln!("Processed {} bytes", cnt),
Expand All @@ -41,10 +41,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
TlsListener::new(tls_acceptor(), listener)
.for_each_concurrent(None, |s| async {
match s {
Ok(stream) => {
handle_stream(stream).await;
Ok((stream, remote_addr)) => {
handle_stream(stream, remote_addr).await;
}
Err(e) => {
if let Some(remote_addr) = e.remote_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {:?}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-change-certificate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,22 @@ async fn main() {
tokio::select! {
conn = listener.accept() => {
match conn.expect("Tls listener stream should be infinite") {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
let tx = tx.clone();
let counter = counter.clone();
tokio::spawn(async move {
let svc = service_fn(move |request| handle_request(tx.clone(), counter.clone(), request));
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("Application error (client address: {remote_addr}): {}", err);
}
});
},
Err(e) => {
if let Some(remote_addr) = e.remote_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Bad connection: {}", e);
}
}
Expand Down
8 changes: 6 additions & 2 deletions examples/http-low-level.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@ async fn main() {
listener
.for_each(|r| async {
match r {
Ok(conn) => {
Ok((conn, remote_addr)) => {
let http = http.clone();
tokio::spawn(async move {
if let Err(err) = http.serve_connection(conn, svc).await {
eprintln!("Application error: {}", err);
eprintln!("[client {remote_addr}] Application error: {}", err);
}
});
}
Err(err) => {
if let Some(remote_addr) = err.remote_addr() {
eprint!("[client {remote_addr}] ");
}

eprintln!("Error accepting connection: {}", err);
}
}
Expand Down
20 changes: 11 additions & 9 deletions examples/http-stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures_util::stream::StreamExt;
use futures_util::stream::{StreamExt, TryStreamExt};
use hyper::server::accept;
use hyper::server::conn::AddrIncoming;
use hyper::service::{make_service_fn, service_fn};
Expand All @@ -22,14 +22,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
});

// This uses a filter to handle errors with connecting
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?).filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
});
let incoming = TlsListener::new(tls_acceptor(), AddrIncoming::bind(&addr)?)
.filter(|conn| {
if let Err(err) = conn {
eprintln!("Error: {:?}", err);
ready(false)
} else {
ready(true)
}
})
.map_ok(|(conn, _remote_addr)| conn);

let server = Server::builder(accept::from_stream(incoming)).serve(new_svc);
server.await?;
Expand Down
25 changes: 19 additions & 6 deletions src/hyper.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ use std::ops::{Deref, DerefMut};
impl AsyncAccept for AddrIncoming {
type Connection = AddrStream;
type Error = std::io::Error;
type Address = std::net::SocketAddr;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx)
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
<AddrIncoming as HyperAccept>::poll_accept(self, cx).map_ok(|conn| {
let remote_addr = conn.remote_addr();
(conn, remote_addr)
})
}
}

Expand All @@ -22,6 +26,11 @@ pin_project! {
/// Unfortunately, it isn't possible to use a blanket impl, due to coherence rules.
/// At least until [RFC 1210](https://rust-lang.github.io/rfcs/1210-impl-specialization.html)
/// (specialization) is stabilized.
///
/// Note that, because `hyper::server::accept::Accept` does not expose the
/// remote address, the implementation of `AsyncAccept` for `WrappedAccept`
/// doesn't expose it either. That is, [`AsyncAccept::Address`] is `()` in
/// this case.
//#[cfg_attr(docsrs, doc(cfg(any(feature = "hyper-h1", feature = "hyper-h2"))))]
pub struct WrappedAccept<A> {
// sadly, pin-project-lite doesn't suport tuple structs :(
Expand All @@ -46,12 +55,16 @@ where
{
type Connection = A::Conn;
type Error = A::Error;
type Address = ();

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
self.project().inner.poll_accept(cx)
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
self.project()
.inner
.poll_accept(cx)
.map_ok(|conn| (conn, ()))
}
}

Expand Down Expand Up @@ -95,12 +108,12 @@ where
T: AsyncTls<A::Connection>,
{
type Conn = T::Stream;
type Error = Error<A::Error, T::Error>;
type Error = Error<A::Error, T::Error, A::Address>;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
self.poll_next(cx)
self.poll_next(cx).map_ok(|(conn, _)| conn)
}
}
110 changes: 95 additions & 15 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use futures_util::stream::{FuturesUnordered, Stream, StreamExt};
use pin_project_lite::pin_project;
#[cfg(feature = "rt")]
pub use spawning_handshake::SpawningHandshakes;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
Expand Down Expand Up @@ -67,14 +68,19 @@ pub trait AsyncTls<C: AsyncRead + AsyncWrite>: Clone {
pub trait AsyncAccept {
/// The type of the connection that is accepted.
type Connection: AsyncRead + AsyncWrite;
/// The type of the remote address, such as [`std::net::SocketAddr`].
///
/// If no remote address can be determined (such as for mock connections),
/// `()` or a similar dummy type can be used.
type Address: Debug;
/// The type of error that may be returned.
type Error;

/// Poll to accept the next connection.
fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>>;
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>>;

/// Return a new `AsyncAccept` that stops accepting connections after
/// `ender` completes.
Expand Down Expand Up @@ -126,7 +132,7 @@ pin_project! {
#[pin]
listener: A,
tls: T,
waiting: FuturesUnordered<Timeout<T::AcceptFuture>>,
waiting: FuturesUnordered<FutureWithExtraData<Timeout<T::AcceptFuture>, A::Address>>,
max_handshakes: usize,
timeout: Duration,
}
Expand All @@ -143,19 +149,35 @@ pub struct Builder<T> {
/// Wraps errors from either the listener or the TLS Acceptor
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum Error<LE: std::error::Error, TE: std::error::Error> {
// TODO: It would probably be more simple and more future-proof to use the
// `AsyncAccept` and `AsyncTls` implementations as the type parameters here, so
// that their associated types can be used in the fields
// (i.e. `error: A::Error, remote_addr: A::Address`), but that would require us
// to either hand-write `impl Debug` or use a proc-macro crate like
// `impl-tools` to derive `Debug` with custom bounds,
// due to https://github.com/rust-lang/rust/issues/26925
pub enum Error<LE: std::error::Error, TE: std::error::Error, A> {
/// An error that arose from the listener ([AsyncAccept::Error])
#[error("{0}")]
ListenerError(#[source] LE),
/// An error that occurred during the TLS accept handshake
#[error("{0}")]
TlsAcceptError(#[source] TE),
// TODO: is there any way we could include thee original connection, or maybe some
// info about it here?
#[error("{error}")]
#[non_exhaustive]
TlsAcceptError {
/// The error that occurred.
#[source]
error: TE,

/// The client's address and port.
remote_addr: A,
},
/// The TLS handshake timed out
#[error("Timeout during TLS handshake")]
#[non_exhaustive]
HandshakeTimeout {},
HandshakeTimeout {
/// The client's address and port.
remote_addr: A,
},
}

impl<A: AsyncAccept, T> TlsListener<A, T>
Expand Down Expand Up @@ -207,17 +229,19 @@ where
A::Error: std::error::Error,
T: AsyncTls<A::Connection>,
{
type Item = Result<T::Stream, Error<A::Error, T::Error>>;
type Item = Result<(T::Stream, A::Address), Error<A::Error, T::Error, A::Address>>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.project();

while this.waiting.len() < *this.max_handshakes {
match this.listener.as_mut().poll_accept(cx) {
Poll::Pending => break,
Poll::Ready(Some(Ok(conn))) => {
this.waiting
.push(timeout(*this.timeout, this.tls.accept(conn)));
Poll::Ready(Some(Ok((conn, address)))) => {
this.waiting.push(FutureWithExtraData::new(
timeout(*this.timeout, this.tls.accept(conn)),
address,
));
}
Poll::Ready(Some(Err(e))) => {
return Poll::Ready(Some(Err(Error::ListenerError(e))));
Expand All @@ -227,10 +251,15 @@ where
}

match this.waiting.poll_next_unpin(cx) {
Poll::Ready(Some(Ok(conn))) => Poll::Ready(Some(conn.map_err(Error::TlsAcceptError))),
Poll::Ready(Some((Ok(result), remote_addr))) => Poll::Ready(Some(match result {
Ok(conn) => Ok((conn, remote_addr)),
Err(error) => Err(Error::TlsAcceptError { error, remote_addr }),
})),
// The handshake timed out, try getting another connection from the
// queue
Poll::Ready(Some(Err(_))) => Poll::Ready(Some(Err(Error::HandshakeTimeout()))),
Poll::Ready(Some((Err(_), remote_addr))) => {
Poll::Ready(Some(Err(Error::HandshakeTimeout { remote_addr })))
}
_ => Poll::Pending,
}
}
Expand Down Expand Up @@ -337,6 +366,19 @@ impl<T> Builder<T> {
}
}

impl<LE: std::error::Error, TE: std::error::Error, A> Error<LE, TE, A> {
/// Returns the client's address and port, if known.
pub fn remote_addr(&self) -> Option<&A> {
match self {
Self::ListenerError(_) => None,

Self::TlsAcceptError { remote_addr, .. } | Self::HandshakeTimeout { remote_addr } => {
Some(remote_addr)
}
}
}
}

/// Create a new Builder for a TlsListener
///
/// `server_config` will be used to configure the TLS sessions.
Expand All @@ -361,11 +403,12 @@ pin_project! {
impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
type Connection = A::Connection;
type Error = A::Error;
type Address = A::Address;

fn poll_accept(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Connection, Self::Error>>> {
) -> Poll<Option<Result<(Self::Connection, Self::Address), Self::Error>>> {
let this = self.project();

match this.ender.poll(cx) {
Expand All @@ -374,3 +417,40 @@ impl<A: AsyncAccept, E: Future> AsyncAccept for Until<A, E> {
}
}
}

pin_project! {
struct FutureWithExtraData<Fut, X> {
#[pin]
future: Fut,
extra: Option<X>,
}
}

impl<Fut, X> FutureWithExtraData<Fut, X> {
fn new(future: Fut, extra: X) -> Self {
Self {
future,
extra: Some(extra),
}
}
}

impl<Fut, X> Future for FutureWithExtraData<Fut, X>
where
Fut: Future,
{
type Output = (Fut::Output, X);

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
let extra = this.extra;

this.future.poll(cx).map(|output| {
let extra = extra
.take()
.expect("this future has already been polled to completion");

(output, extra)
})
}
}
Loading

0 comments on commit e920fbb

Please sign in to comment.