Skip to content

Commit

Permalink
util: Add a Collected body and combinator (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucioFranco committed Oct 19, 2022
1 parent dcb005a commit 26a9bee
Show file tree
Hide file tree
Showing 6 changed files with 382 additions and 5 deletions.
140 changes: 140 additions & 0 deletions http-body-util/src/collected.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
use std::{
convert::Infallible,
pin::Pin,
task::{Context, Poll},
};

use bytes::{Buf, Bytes};
use http::HeaderMap;
use http_body::{Body, Frame};

use crate::util::BufList;

/// A collected body produced by [`BodyExt::collect`] which collects all the DATA frames
/// and trailers.
#[derive(Debug)]
pub struct Collected<B> {
bufs: BufList<B>,
trailers: Option<HeaderMap>,
}

impl<B: Buf> Collected<B> {
/// If there is a trailers frame buffered, returns a reference to it.
///
/// Returns `None` if the body contained no trailers.
pub fn trailers(&self) -> Option<&HeaderMap> {
self.trailers.as_ref()
}

/// Aggregate this buffered into a [`Buf`].
pub fn aggregate(self) -> impl Buf {
self.bufs
}

/// Convert this body into a [`Bytes`].
pub fn to_bytes(mut self) -> Bytes {
self.bufs.copy_to_bytes(self.bufs.remaining())
}

pub(crate) fn push_frame(&mut self, frame: Frame<B>) {
if frame.is_data() {
let data = frame.into_data().unwrap();
self.bufs.push(data);
} else if frame.is_trailers() {
let trailers = frame.into_trailers().unwrap();

if let Some(current) = &mut self.trailers {
current.extend(trailers.into_iter());
} else {
self.trailers = Some(trailers);
}
}
}
}

impl<B: Buf> Body for Collected<B> {
type Data = B;
type Error = Infallible;

fn poll_frame(
mut self: Pin<&mut Self>,
_: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
let frame = if let Some(data) = self.bufs.pop() {
Frame::data(data)
} else if let Some(trailers) = self.trailers.take() {
Frame::trailers(trailers)
} else {
return Poll::Ready(None);
};

Poll::Ready(Some(Ok(frame)))
}
}

impl<B> Default for Collected<B> {
fn default() -> Self {
Self {
bufs: BufList::default(),
trailers: None,
}
}
}

impl<B> Unpin for Collected<B> {}

#[cfg(test)]
mod tests {
use std::convert::{Infallible, TryInto};

use futures_util::stream;

use crate::{BodyExt, Full, StreamBody};

use super::*;

#[tokio::test]
async fn full_body() {
let body = Full::new(&b"hello"[..]);

let buffered = body.collect().await.unwrap();

let mut buf = buffered.to_bytes();

assert_eq!(&buf.copy_to_bytes(buf.remaining())[..], &b"hello"[..]);
}

#[tokio::test]
async fn segmented_body() {
let bufs = [&b"hello"[..], &b"world"[..], &b"!"[..]];

let body = StreamBody::new(stream::iter(bufs.map(Frame::data).map(Ok::<_, Infallible>)));

let buffered = body.collect().await.unwrap();

let mut buf = buffered.to_bytes();

assert_eq!(&buf.copy_to_bytes(buf.remaining())[..], b"helloworld!");
}

#[tokio::test]
async fn trailers() {
let mut trailers = HeaderMap::new();
trailers.insert("this", "a trailer".try_into().unwrap());
let bufs = [
Frame::data(&b"hello"[..]),
Frame::data(&b"world!"[..]),
Frame::trailers(trailers.clone()),
];

let body = StreamBody::new(stream::iter(bufs.map(Ok::<_, Infallible>)));

let buffered = body.collect().await.unwrap();

assert_eq!(&trailers, buffered.trailers().unwrap());

let mut buf = buffered.to_bytes();

assert_eq!(&buf.copy_to_bytes(buf.remaining())[..], b"helloworld!");
}
}
38 changes: 38 additions & 0 deletions http-body-util/src/combinators/collect.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
use std::{
pin::Pin,
task::{Context, Poll},
};

use futures_util::Future;
use http_body::Body;
use pin_project_lite::pin_project;

pin_project! {
/// Future that resolves into a `Collected`.
pub struct Collect<T: ?Sized> {
#[pin]
pub(crate) body: T
}
}

impl<T: Body + Unpin + ?Sized> Future for Collect<T> {
type Output = Result<crate::Collected<T::Data>, T::Error>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> std::task::Poll<Self::Output> {
let mut collected = crate::Collected::default();

let mut me = self.project();

loop {
let frame = futures_util::ready!(Pin::new(&mut me.body).poll_frame(cx));

let frame = if let Some(frame) = frame {
frame?
} else {
return Poll::Ready(Ok(collected));
};

collected.push_frame(frame);
}
}
}
2 changes: 2 additions & 0 deletions http-body-util/src/combinators/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
//! Combinators for the `Body` trait.

mod box_body;
mod collect;
mod frame;
mod map_err;
mod map_frame;

pub use self::{
box_body::{BoxBody, UnsyncBoxBody},
collect::Collect,
frame::Frame,
map_err::MapErr,
map_frame::MapFrame,
Expand Down
14 changes: 14 additions & 0 deletions http-body-util/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@
//!
//! [`Empty`] and [`Full`] provide simple implementations.

mod collected;
pub mod combinators;
mod either;
mod empty;
mod full;
mod limited;
mod stream;

mod util;

use self::combinators::{BoxBody, MapErr, MapFrame, UnsyncBoxBody};

pub use self::collected::Collected;
pub use self::either::Either;
pub use self::empty::Empty;
pub use self::full::Full;
Expand Down Expand Up @@ -70,6 +75,15 @@ pub trait BodyExt: http_body::Body {
{
UnsyncBoxBody::new(self)
}

/// Turn this body into [`Collected`] body which will collect all the DATA frames
/// and trailers.
fn collect(self) -> combinators::Collect<Self>
where
Self: Sized,
{
combinators::Collect { body: self }
}
}

impl<T: ?Sized> BodyExt for T where T: http_body::Body {}
153 changes: 153 additions & 0 deletions http-body-util/src/util.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
use std::collections::VecDeque;
use std::io::IoSlice;

use bytes::{Buf, BufMut, Bytes, BytesMut};

#[derive(Debug)]
pub(crate) struct BufList<T> {
bufs: VecDeque<T>,
}

impl<T: Buf> BufList<T> {
#[inline]
pub(crate) fn push(&mut self, buf: T) {
debug_assert!(buf.has_remaining());
self.bufs.push_back(buf);
}

#[inline]
pub(crate) fn pop(&mut self) -> Option<T> {
self.bufs.pop_front()
}
}

impl<T: Buf> Buf for BufList<T> {
#[inline]
fn remaining(&self) -> usize {
self.bufs.iter().map(|buf| buf.remaining()).sum()
}

#[inline]
fn chunk(&self) -> &[u8] {
self.bufs.front().map(Buf::chunk).unwrap_or_default()
}

#[inline]
fn advance(&mut self, mut cnt: usize) {
while cnt > 0 {
{
let front = &mut self.bufs[0];
let rem = front.remaining();
if rem > cnt {
front.advance(cnt);
return;
} else {
front.advance(rem);
cnt -= rem;
}
}
self.bufs.pop_front();
}
}

#[inline]
fn chunks_vectored<'t>(&'t self, dst: &mut [IoSlice<'t>]) -> usize {
if dst.is_empty() {
return 0;
}
let mut vecs = 0;
for buf in &self.bufs {
vecs += buf.chunks_vectored(&mut dst[vecs..]);
if vecs == dst.len() {
break;
}
}
vecs
}

#[inline]
fn copy_to_bytes(&mut self, len: usize) -> Bytes {
// Our inner buffer may have an optimized version of copy_to_bytes, and if the whole
// request can be fulfilled by the front buffer, we can take advantage.
match self.bufs.front_mut() {
Some(front) if front.remaining() == len => {
let b = front.copy_to_bytes(len);
self.bufs.pop_front();
b
}
Some(front) if front.remaining() > len => front.copy_to_bytes(len),
_ => {
assert!(len <= self.remaining(), "`len` greater than remaining");
let mut bm = BytesMut::with_capacity(len);
bm.put(self.take(len));
bm.freeze()
}
}
}
}

impl<T> Default for BufList<T> {
fn default() -> Self {
BufList {
bufs: VecDeque::new(),
}
}
}

#[cfg(test)]
mod tests {
use std::ptr;

use super::*;

fn hello_world_buf() -> BufList<Bytes> {
BufList {
bufs: vec![Bytes::from("Hello"), Bytes::from(" "), Bytes::from("World")].into(),
}
}

#[test]
fn to_bytes_shorter() {
let mut bufs = hello_world_buf();
let old_ptr = bufs.chunk().as_ptr();
let start = bufs.copy_to_bytes(4);
assert_eq!(start, "Hell");
assert!(ptr::eq(old_ptr, start.as_ptr()));
assert_eq!(bufs.chunk(), b"o");
assert!(ptr::eq(old_ptr.wrapping_add(4), bufs.chunk().as_ptr()));
assert_eq!(bufs.remaining(), 7);
}

#[test]
fn to_bytes_eq() {
let mut bufs = hello_world_buf();
let old_ptr = bufs.chunk().as_ptr();
let start = bufs.copy_to_bytes(5);
assert_eq!(start, "Hello");
assert!(ptr::eq(old_ptr, start.as_ptr()));
assert_eq!(bufs.chunk(), b" ");
assert_eq!(bufs.remaining(), 6);
}

#[test]
fn to_bytes_longer() {
let mut bufs = hello_world_buf();
let start = bufs.copy_to_bytes(7);
assert_eq!(start, "Hello W");
assert_eq!(bufs.remaining(), 4);
}

#[test]
fn one_long_buf_to_bytes() {
let mut buf = BufList::default();
buf.push(b"Hello World" as &[_]);
assert_eq!(buf.copy_to_bytes(5), "Hello");
assert_eq!(buf.chunk(), b" World");
}

#[test]
#[should_panic(expected = "`len` greater than remaining")]
fn buf_to_bytes_too_many() {
hello_world_buf().copy_to_bytes(42);
}
}
Loading

0 comments on commit 26a9bee

Please sign in to comment.