diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index 4cf7d289..7e7844c1 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -184,6 +184,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" + [[package]] name = "bzip2" version = "0.4.4" @@ -373,6 +379,12 @@ version = "0.32.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e629b9b98ef3dd8afe6ca2bd0f89306cec16d43d907889945bc5d6687f2f13c7" +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "hex-conservative" version = "0.2.1" @@ -526,6 +538,18 @@ dependencies = [ "adler2", ] +[[package]] +name = "mio" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "80e04d1dcff3aae0704555fe5fee3bcfaf3d1fdf8a7e521d5b9d2b42acb52cec" +dependencies = [ + "hermit-abi", + "libc", + "wasi", + "windows-sys 0.52.0", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -798,6 +822,16 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "socket2" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socks" version = "0.3.4" @@ -869,7 +903,24 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2209a14885b74764cce87ffa777ffa1b8ce81a3f3166c6f886b83337fe7e077f" dependencies = [ "backtrace", + "bytes", + "libc", + "mio", "pin-project-lite", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 7bbcb62d..7f0f92d9 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -184,6 +184,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" + [[package]] name = "bzip2" version = "0.4.4" @@ -827,6 +833,16 @@ version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" +[[package]] +name = "socket2" +version = "0.5.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c970269d99b64e60ec3bd6ad27270092a5394c4e309314b18ae3fe575695fbe8" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + [[package]] name = "socks" version = "0.3.4" @@ -898,11 +914,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc3a2344dafbe23a245241fe8b09735b521110d30fcefbbd5feb1797ca35d17" dependencies = [ "backtrace", + "bytes", "io-uring", "libc", "mio", "pin-project-lite", "slab", + "socket2", + "tokio-macros", + "windows-sys 0.52.0", +] + +[[package]] +name = "tokio-macros" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] diff --git a/bitreq/Cargo.toml b/bitreq/Cargo.toml index 6c38a85d..89d28000 100644 --- a/bitreq/Cargo.toml +++ b/bitreq/Cargo.toml @@ -28,12 +28,13 @@ webpki-roots = { version = "0.25.2", optional = true } rustls-webpki = { version = "0.101.0", optional = true } log = { version = "0.4.0", optional = true } # For the async feature: -tokio = { version = "1.0", features = ["rt", "rt-multi-thread"], optional = true } +tokio = { version = "1.0", features = ["rt", "net", "io-util", "time"], optional = true } tokio-rustls = { version = "0.24", optional = true } [dev-dependencies] tiny_http = "0.12" chrono = "0.4.0" +tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "time"] } [package.metadata.docs.rs] features = ["json-using-serde", "proxy", "https"] diff --git a/bitreq/src/connection.rs b/bitreq/src/connection.rs index 87e053b1..42b1fd6f 100644 --- a/bitreq/src/connection.rs +++ b/bitreq/src/connection.rs @@ -1,10 +1,23 @@ use core::time::Duration; use std::env; +#[cfg(feature = "async")] +use std::future::Future; use std::io::{self, Read, Write}; use std::net::{TcpStream, ToSocketAddrs}; +#[cfg(feature = "async")] +use std::pin::Pin; use std::time::Instant; +#[cfg(all(feature = "async", feature = "proxy"))] +use tokio::io::AsyncReadExt; +#[cfg(feature = "async")] +use tokio::io::{AsyncRead, AsyncWriteExt}; +#[cfg(feature = "async")] +use tokio::net::TcpStream as AsyncTcpStream; + use crate::request::ParsedRequest; +#[cfg(feature = "async")] +use crate::Response; use crate::{Error, Method, ResponseLazy}; type UnsecuredStream = TcpStream; @@ -18,6 +31,8 @@ pub(crate) enum HttpStream { Unsecured(UnsecuredStream, Option), #[cfg(feature = "rustls")] Secured(Box, Option), + #[cfg(feature = "async")] + Buffer(std::io::Cursor>), } impl HttpStream { @@ -29,6 +44,11 @@ impl HttpStream { fn create_secured(reader: SecuredStream, timeout_at: Option) -> HttpStream { HttpStream::Secured(Box::new(reader), timeout_at) } + + #[cfg(feature = "async")] + pub(crate) fn create_buffer(buffer: Vec) -> HttpStream { + HttpStream::Buffer(std::io::Cursor::new(buffer)) + } } fn timeout_err() -> io::Error { @@ -64,6 +84,8 @@ impl Read for HttpStream { timeout(inner.get_ref(), *timeout_at)?; inner.read(buf) } + #[cfg(feature = "async")] + HttpStream::Buffer(cursor) => std::io::Read::read(cursor, buf), }; match result { Err(e) if e.kind() == io::ErrorKind::WouldBlock => { @@ -75,6 +97,46 @@ impl Read for HttpStream { } } +#[cfg(feature = "async")] +type AsyncUnsecuredStream = AsyncTcpStream; + +#[cfg(feature = "async-https")] +type AsyncSecuredStream = rustls_stream::AsyncSecuredStream; + +#[cfg(feature = "async")] +pub(crate) enum AsyncHttpStream { + Unsecured(AsyncUnsecuredStream), + #[cfg(feature = "async-https")] + Secured(Box), +} + +#[cfg(feature = "async")] +impl AsyncHttpStream { + fn create_unsecured(stream: AsyncUnsecuredStream) -> AsyncHttpStream { + AsyncHttpStream::Unsecured(stream) + } + + #[cfg(feature = "async-https")] + fn create_secured(stream: AsyncSecuredStream) -> AsyncHttpStream { + AsyncHttpStream::Secured(Box::new(stream)) + } +} + +#[cfg(feature = "async")] +impl AsyncRead for AsyncHttpStream { + fn poll_read( + mut self: core::pin::Pin<&mut Self>, + cx: &mut core::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> core::task::Poll> { + match &mut *self { + AsyncHttpStream::Unsecured(inner) => core::pin::Pin::new(inner).poll_read(cx, buf), + #[cfg(feature = "async-https")] + AsyncHttpStream::Secured(inner) => core::pin::Pin::new(inner).poll_read(cx, buf), + } + } +} + /// An async connection to the server for sending /// [`Request`](struct.Request.html)s. #[cfg(feature = "async")] @@ -95,25 +157,130 @@ impl AsyncConnection { AsyncConnection { request, timeout_at } } + /// Asynchronously connect to the server. + async fn connect(&self) -> Result { + let tcp_connect = |host: String, port: u32| async move { + let addrs = tokio::net::lookup_host((host.as_str(), port as u16)) + .await + .map_err(Error::IoError)?; + let addrs: Vec<_> = addrs.collect(); + let addrs_count = addrs.len(); + + if addrs.is_empty() { + return Err(Error::AddressNotFound); + } + + // Try all resolved addresses. Return the first one to which we could connect. If all + // failed return the last error encountered. + for (i, addr) in addrs.iter().enumerate() { + match AsyncTcpStream::connect(addr).await { + Ok(s) => return Ok(s), + Err(e) => + if i == addrs_count - 1 { + return Err(Error::IoError(e)); + }, + } + } + + Err(Error::AddressNotFound) + }; + + #[cfg(feature = "proxy")] + match &self.request.config.proxy { + Some(proxy) => { + // do proxy things + let mut tcp = tcp_connect(proxy.server.clone(), proxy.port).await?; + + let proxy_request = format!("{}", proxy.connect(&self.request)); + tcp.write_all(proxy_request.as_bytes()).await?; + tcp.flush().await?; + + let mut proxy_response = Vec::new(); + let mut buf = vec![0; 256]; + loop { + let n = tcp.read(&mut buf).await?; + proxy_response.extend_from_slice(&buf[..n]); + if n < 256 { + break; + } + } + + crate::Proxy::verify_response(&proxy_response)?; + + Ok(tcp) + } + None => tcp_connect(self.request.url.host.clone(), self.request.url.port.port()).await, + } + + #[cfg(not(feature = "proxy"))] + tcp_connect(self.request.url.host.clone(), self.request.url.port.port()).await + } + /// Sends the [`Request`](struct.Request.html) asynchronously using HTTPS. #[cfg(feature = "async-https")] - pub(crate) async fn send_https(self) -> Result { - // Use spawn_blocking to run the sync HTTPS code in a thread pool - let sync_conn = Connection { request: self.request, timeout_at: self.timeout_at }; + pub(crate) async fn send_https(self) -> Result { + let timeout = self.timeout_at; + let future = async move { + let is_head = self.request.config.method == Method::Head; + let secured_stream = rustls_stream::create_async_secured_stream(&self).await?; - tokio::task::spawn_blocking(move || sync_conn.send_https()) - .await - .map_err(|e| Error::IoError(io::Error::new(io::ErrorKind::Other, e)))? + #[cfg(feature = "log")] + log::trace!("Reading HTTPS response from {}.", self.request.url.host); + let response = Response::create_async( + secured_stream, + is_head, + self.request.config.max_headers_size, + self.request.config.max_status_line_len, + ) + .await?; + + async_handle_redirects(self, response).await + }; + if let Some(timeout_at) = timeout { + tokio::time::timeout_at(timeout_at.into(), future) + .await + .unwrap_or(Err(Error::IoError(timeout_err()))) + } else { + future.await + } } /// Sends the [`Request`](struct.Request.html) asynchronously using HTTP. - pub(crate) async fn send(self) -> Result { - // Use spawn_blocking to run the sync HTTP code in a thread pool - let sync_conn = Connection { request: self.request, timeout_at: self.timeout_at }; + pub(crate) async fn send(self) -> Result { + let timeout = self.timeout_at; + let future = async move { + let is_head = self.request.config.method == Method::Head; + let bytes = self.request.as_bytes(); - tokio::task::spawn_blocking(move || sync_conn.send()) - .await - .map_err(|e| Error::IoError(io::Error::new(io::ErrorKind::Other, e)))? + #[cfg(feature = "log")] + log::trace!("Establishing TCP connection to {}.", self.request.url.host); + let mut tcp = self.connect().await?; + + // Send request + #[cfg(feature = "log")] + log::trace!("Writing HTTP request."); + tcp.write_all(&bytes).await?; + + // Receive response + #[cfg(feature = "log")] + log::trace!("Reading HTTP response."); + let stream = AsyncHttpStream::create_unsecured(tcp); + let response = Response::create_async( + stream, + is_head, + self.request.config.max_headers_size, + self.request.config.max_status_line_len, + ) + .await?; + async_handle_redirects(self, response).await + }; + if let Some(timeout_at) = timeout { + tokio::time::timeout_at(timeout_at.into(), future) + .await + .unwrap_or(Err(Error::IoError(timeout_err()))) + } else { + future.await + } } } @@ -275,41 +442,83 @@ fn handle_redirects( } } -enum NextHop { - Redirect(Result), - Destination(Connection), +#[cfg(feature = "async")] +fn async_handle_redirects( + connection: AsyncConnection, + mut response: Response, +) -> Pin> + Send>> { + Box::pin(async move { + let status_code = response.status_code; + let url = response.headers.get("location"); + match async_get_redirect(connection, status_code, url) { + NextHopAsync::Redirect(connection) => { + let connection = connection?; + if connection.request.url.https { + #[cfg(not(feature = "async-https"))] + return Err(Error::HttpsFeatureNotEnabled); + #[cfg(feature = "async-https")] + return connection.send_https().await; + } else { + connection.send().await + } + } + NextHopAsync::Destination(connection) => { + let dst_url = connection.request.url; + dst_url.write_base_url_to(&mut response.url).unwrap(); + dst_url.write_resource_to(&mut response.url).unwrap(); + Ok(response) + } + } + }) } -fn get_redirect(mut connection: Connection, status_code: i32, url: Option<&String>) -> NextHop { - match status_code { - 301 | 302 | 303 | 307 => { - let url = match url { - Some(url) => url, - None => return NextHop::Redirect(Err(Error::RedirectLocationMissing)), - }; - #[cfg(feature = "log")] - log::debug!("Redirecting ({}) to: {}", status_code, url); - - match connection.request.redirect_to(url.as_str()) { - Ok(()) => { - if status_code == 303 { - match connection.request.config.method { - Method::Post | Method::Put | Method::Delete => { - connection.request.config.method = Method::Get; +macro_rules! redirect_utils { + ($get_redirect: ident, $NextHop: ident, $Connection: ident, $Response: ident) => { + enum $NextHop { + Redirect(Result<$Connection, Error>), + Destination($Connection), + } + + fn $get_redirect( + mut connection: $Connection, + status_code: i32, + url: Option<&String>, + ) -> $NextHop { + match status_code { + 301 | 302 | 303 | 307 => { + let url = match url { + Some(url) => url, + None => return $NextHop::Redirect(Err(Error::RedirectLocationMissing)), + }; + #[cfg(feature = "log")] + log::debug!("Redirecting ({}) to: {}", status_code, url); + + match connection.request.redirect_to(url.as_str()) { + Ok(()) => { + if status_code == 303 { + match connection.request.config.method { + Method::Post | Method::Put | Method::Delete => { + connection.request.config.method = Method::Get; + } + _ => {} + } } - _ => {} + + $NextHop::Redirect(Ok(connection)) } + Err(err) => $NextHop::Redirect(Err(err)), } - - NextHop::Redirect(Ok(connection)) } - Err(err) => NextHop::Redirect(Err(err)), + _ => $NextHop::Destination(connection), } } - _ => NextHop::Destination(connection), - } + }; } +redirect_utils!(get_redirect, NextHop, Connection, ResponseLazy); +#[cfg(feature = "async")] +redirect_utils!(async_get_redirect, NextHopAsync, AsyncConnection, Response); + /// Enforce the timeout by running the function in a new thread and /// parking the current one with a timeout. /// diff --git a/bitreq/src/connection/rustls_stream.rs b/bitreq/src/connection/rustls_stream.rs index 23854501..e0126e71 100644 --- a/bitreq/src/connection/rustls_stream.rs +++ b/bitreq/src/connection/rustls_stream.rs @@ -8,9 +8,15 @@ use std::net::TcpStream; use std::sync::OnceLock; use rustls::{self, ClientConfig, ClientConnection, RootCertStore, ServerName, StreamOwned}; +#[cfg(feature = "async-https")] +use tokio::io::AsyncWriteExt; +#[cfg(feature = "async-https")] +use tokio_rustls::{client::TlsStream, TlsConnector}; #[cfg(feature = "rustls-webpki")] use webpki_roots::TLS_SERVER_ROOTS; +#[cfg(feature = "async-https")] +use super::{AsyncConnection, AsyncHttpStream}; use super::{Connection, HttpStream}; use crate::Error; @@ -48,7 +54,7 @@ fn build_client_config() -> Arc { Arc::new(config) } -pub fn create_secured_stream(conn: &Connection) -> Result { +pub(super) fn create_secured_stream(conn: &Connection) -> Result { // Rustls setup #[cfg(feature = "log")] log::trace!("Setting up TLS parameters for {}.", conn.request.url.host); @@ -75,3 +81,43 @@ pub fn create_secured_stream(conn: &Connection) -> Result { Ok(HttpStream::create_secured(tls, conn.timeout_at)) } + +// Async TLS implementation + +#[cfg(feature = "async-https")] +pub type AsyncSecuredStream = TlsStream; + +#[cfg(feature = "async-https")] +pub(super) async fn create_async_secured_stream( + conn: &AsyncConnection, +) -> Result { + // Rustls setup + #[cfg(feature = "log")] + log::trace!("Setting up TLS parameters for {}.", conn.request.url.host); + let dns_name = match ServerName::try_from(&*conn.request.url.host) { + Ok(result) => result, + Err(err) => return Err(Error::IoError(io::Error::new(io::ErrorKind::Other, err))), + }; + + let connector = TlsConnector::from(CONFIG.get_or_init(build_client_config).clone()); + + // Connect + #[cfg(feature = "log")] + log::trace!("Establishing TCP connection to {}.", conn.request.url.host); + let tcp = conn.connect().await?; + + // Establish TLS connection + #[cfg(feature = "log")] + log::trace!("Establishing TLS session to {}.", conn.request.url.host); + let mut tls = connector + .connect(dns_name, tcp) + .await + .map_err(|e| Error::IoError(io::Error::new(io::ErrorKind::Other, e)))?; + + // Send request + #[cfg(feature = "log")] + log::trace!("Writing HTTPS request to {}.", conn.request.url.host); + tls.write_all(&conn.request.as_bytes()).await?; + + Ok(AsyncHttpStream::create_secured(tls)) +} diff --git a/bitreq/src/request.rs b/bitreq/src/request.rs index eb375a8b..6b7353ce 100644 --- a/bitreq/src/request.rs +++ b/bitreq/src/request.rs @@ -314,41 +314,32 @@ impl Request { if parsed_request.url.https { #[cfg(feature = "async-https")] { - let is_head = parsed_request.config.method == Method::Head; - let response = AsyncConnection::new(parsed_request).send_https().await?; - Response::create(response, is_head) + AsyncConnection::new(parsed_request).send_https().await } #[cfg(not(feature = "async-https"))] { Err(Error::HttpsFeatureNotEnabled) } } else { - let is_head = parsed_request.config.method == Method::Head; - let response = AsyncConnection::new(parsed_request).send().await?; - Response::create(response, is_head) + AsyncConnection::new(parsed_request).send().await } } - /// Sends this request to the host asynchronously, loaded lazily. + /// Sends this request to the host asynchronously, "loaded lazily". + /// + /// Note that due to API limitations the response is not actually loaded lazily - it is loaded + /// immediately and then can be re-read from the response. In a future version an + /// `AsyncResponseLazy` will be added and lazy loading support will be added. + /// + /// Until then, you should use [`Self::send_async`]. /// /// # Errors /// /// See [`send_async`](struct.Request.html#method.send_async). #[cfg(feature = "async")] pub async fn send_lazy_async(self) -> Result { - let parsed_request = ParsedRequest::new(self)?; - if parsed_request.url.https { - #[cfg(feature = "async-https")] - { - AsyncConnection::new(parsed_request).send_https().await - } - #[cfg(not(feature = "async-https"))] - { - Err(Error::HttpsFeatureNotEnabled) - } - } else { - AsyncConnection::new(parsed_request).send().await - } + let response = self.send_async().await?; + Ok(ResponseLazy::dummy_from_response(response)) } } diff --git a/bitreq/src/response.rs b/bitreq/src/response.rs index fe9f674d..96e13e37 100644 --- a/bitreq/src/response.rs +++ b/bitreq/src/response.rs @@ -1,8 +1,13 @@ use alloc::collections::BTreeMap; use core::str; +#[cfg(feature = "async")] +use std::future::Future; #[cfg(feature = "std")] use std::io::{self, BufReader, Bytes, Read}; +#[cfg(feature = "async")] +use tokio::io::{AsyncRead, AsyncReadExt}; + #[cfg(feature = "std")] use crate::connection::HttpStream; use crate::Error; @@ -62,6 +67,71 @@ impl Response { Ok(Response { status_code, reason_phrase, headers, url, body }) } + #[cfg(feature = "async")] + /// Fully read a [`Response`] from an async stream. + /// + /// When this crate was originally made "async", it actually just spawned sync requests on + /// background threads and waited on their completion rather than actually doing async reads. + /// In order to avoid changing the API while fixing this, we read the full response but then + /// return a "lazy" response that has the full contents pre-read. + pub(crate) async fn create_async( + stream: R, + is_head: bool, + max_headers_size: Option, + max_status_line_len: Option, + ) -> Result { + use HttpStreamState::*; + + let mut stream = tokio::io::BufReader::with_capacity(BACKING_READ_BUFFER_LENGTH, stream); + + let ResponseMetadata { + status_code, + reason_phrase, + mut headers, + state, + max_trailing_headers_size, + } = read_metadata_async(&mut stream, max_headers_size, max_status_line_len).await?; + + let mut body = Vec::new(); + if !is_head && status_code != 204 && status_code != 304 { + match state { + EndOnClose => { + while let Some(byte_result) = read_until_closed_async(&mut stream).await { + let (byte, length) = byte_result?; + body.reserve(length); + body.push(byte); + } + } + ContentLength(mut length) => { + while let Some(byte_result) = + read_with_content_length_async(&mut stream, &mut length).await + { + let (byte, expected_length) = byte_result?; + body.reserve(expected_length); + body.push(byte); + } + } + Chunked(mut expecting_chunks, mut chunk_length, mut content_length) => + while let Some(byte_result) = read_chunked_async( + &mut stream, + &mut headers, + &mut expecting_chunks, + &mut chunk_length, + &mut content_length, + max_trailing_headers_size, + ) + .await + { + let (byte, length) = byte_result?; + body.reserve(length); + body.push(byte); + }, + } + } + + Ok(Response { status_code, reason_phrase, headers, url: String::new(), body }) + } + /// Returns the body as an `&str`. /// /// # Errors @@ -251,6 +321,20 @@ impl ResponseLazy { max_trailing_headers_size, }) } + + #[cfg(feature = "async")] + pub(crate) fn dummy_from_response(response: Response) -> ResponseLazy { + let http_stream = HttpStream::create_buffer(response.body); + ResponseLazy { + status_code: response.status_code, + reason_phrase: response.reason_phrase, + headers: response.headers, + url: response.url, + stream: BufReader::with_capacity(1, http_stream).bytes(), + state: HttpStreamState::EndOnClose, + max_trailing_headers_size: None, + } + } } #[cfg(feature = "std")] @@ -301,140 +385,6 @@ impl Read for ResponseLazy { } } -#[cfg(feature = "std")] -fn read_until_closed(bytes: &mut HttpStreamBytes) -> Option<::Item> { - if let Some(byte) = bytes.next() { - match byte { - Ok(byte) => Some(Ok((byte, 1))), - Err(err) => Some(Err(Error::IoError(err))), - } - } else { - None - } -} - -#[cfg(feature = "std")] -fn read_with_content_length( - bytes: &mut HttpStreamBytes, - content_length: &mut usize, -) -> Option<::Item> { - if *content_length > 0 { - *content_length -= 1; - - if let Some(byte) = bytes.next() { - match byte { - // Cap Content-Length to 16KiB, to avoid out-of-memory issues. - Ok(byte) => return Some(Ok((byte, (*content_length).min(MAX_CONTENT_LENGTH) + 1))), - Err(err) => return Some(Err(Error::IoError(err))), - } - } - } - None -} - -#[cfg(feature = "std")] -fn read_trailers( - bytes: &mut HttpStreamBytes, - headers: &mut BTreeMap, - mut max_headers_size: Option, -) -> Result<(), Error> { - loop { - let trailer_line = read_line(bytes, max_headers_size, Error::HeadersOverflow)?; - if let Some(ref mut max_headers_size) = max_headers_size { - *max_headers_size -= trailer_line.len() + 2; - } - if let Some((header, value)) = parse_header(trailer_line) { - headers.insert(header, value); - } else { - break; - } - } - Ok(()) -} - -#[cfg(feature = "std")] -fn read_chunked( - bytes: &mut HttpStreamBytes, - headers: &mut BTreeMap, - expecting_more_chunks: &mut bool, - chunk_length: &mut usize, - content_length: &mut usize, - max_trailing_headers_size: Option, -) -> Option<::Item> { - if !*expecting_more_chunks && *chunk_length == 0 { - return None; - } - - if *chunk_length == 0 { - // Max length of the chunk length line is 1KB: not too long to - // take up much memory, long enough to tolerate some chunk - // extensions (which are ignored). - - // Get the size of the next chunk - let length_line = match read_line(bytes, Some(1024), Error::MalformedChunkLength) { - Ok(line) => line, - Err(err) => return Some(Err(err)), - }; - - // Note: the trim() and check for empty lines shouldn't be - // needed according to the RFC, but we might as well, it's a - // small change and it fixes a few servers. - let incoming_length = if length_line.is_empty() { - 0 - } else { - let length = if let Some(i) = length_line.find(';') { - length_line[..i].trim() - } else { - length_line.trim() - }; - match usize::from_str_radix(length, 16) { - Ok(length) => length, - Err(_) => return Some(Err(Error::MalformedChunkLength)), - } - }; - - if incoming_length == 0 { - if let Err(err) = read_trailers(bytes, headers, max_trailing_headers_size) { - return Some(Err(err)); - } - - *expecting_more_chunks = false; - headers.insert("content-length".to_string(), (*content_length).to_string()); - headers.remove("transfer-encoding"); - return None; - } - *chunk_length = incoming_length; - *content_length += incoming_length; - } - - if *chunk_length > 0 { - *chunk_length -= 1; - if let Some(byte) = bytes.next() { - match byte { - Ok(byte) => { - // If we're at the end of the chunk... - if *chunk_length == 0 { - //...read the trailing \r\n of the chunk, and - // possibly return an error instead. - - // TODO: Maybe this could be written in a way - // that doesn't discard the last ok byte if - // the \r\n reading fails? - if let Err(err) = read_line(bytes, Some(2), Error::MalformedChunkEnd) { - return Some(Err(err)); - } - } - - return Some(Ok((byte, (*chunk_length).min(MAX_CONTENT_LENGTH) + 1))); - } - Err(err) => return Some(Err(Error::IoError(err))), - } - } - } - - None -} - #[cfg(feature = "std")] enum HttpStreamState { // No Content-Length, and Transfer-Encoding != chunked, so we just @@ -464,96 +414,259 @@ struct ResponseMetadata { max_trailing_headers_size: Option, } -#[cfg(feature = "std")] -fn read_metadata( - stream: &mut HttpStreamBytes, - mut max_headers_size: Option, - max_status_line_len: Option, -) -> Result { - let line = read_line(stream, max_status_line_len, Error::StatusLineOverflow)?; - let (status_code, reason_phrase) = parse_status_line(&line); - - let mut headers = BTreeMap::new(); - loop { - let line = read_line(stream, max_headers_size, Error::HeadersOverflow)?; - if line.is_empty() { - // Body starts here - break; - } - if let Some(ref mut max_headers_size) = max_headers_size { - *max_headers_size -= line.len() + 2; - } - if let Some(header) = parse_header(line) { - headers.insert(header.0, header.1); - } +macro_rules! maybe_await { + ($e: expr, await) => { + $e.await + }; + ($e: expr,) => { + $e + }; +} + +#[cfg(feature = "async")] +/// We need to mungle [`AsyncRead`] to look like an iterator, which we do here. +trait AsyncIteratorReadExt { + fn next(&mut self) -> impl Future>>; +} + +#[cfg(feature = "async")] +impl AsyncIteratorReadExt for T { + fn next(&mut self) -> impl Future>> { + async { Some(self.read_u8().await) } } +} - let mut chunked = false; - let mut content_length = None; - for (header, value) in &headers { - // Handle the Transfer-Encoding header - if header.to_lowercase().trim() == "transfer-encoding" - && value.to_lowercase().trim() == "chunked" - { - chunked = true; +macro_rules! define_read_methods { + (($read_until_closed: ident, $read_with_content_length: ident, $read_trailers: ident, $read_chunked: ident, $read_metadata: ident, $read_line: ident)<$($arg: ident : $($argty: path $(|)?)*),*>, $stream_type: ident $(, $async: tt, $await: tt)?) => { + $($async)? fn $read_until_closed<$($arg: $($argty +)*),*>( + bytes: &mut $stream_type, + ) -> Option<::Item> { + if let Some(byte) = maybe_await!(bytes.next(), $($await)?) { + match byte { + Ok(byte) => Some(Ok((byte, 1))), + Err(err) => Some(Err(Error::IoError(err))), + } + } else { + None + } } - // Handle the Content-Length header - if header.to_lowercase().trim() == "content-length" { - match str::parse::(value.trim()) { - Ok(length) => content_length = Some(length), - Err(_) => return Err(Error::MalformedContentLength), + $($async)? fn $read_with_content_length<$($arg: $($argty +)*),*>( + bytes: &mut $stream_type, + content_length: &mut usize, + ) -> Option<::Item> { + if *content_length > 0 { + *content_length -= 1; + + if let Some(byte) = maybe_await!(bytes.next(), $($await)?) { + match byte { + // Cap Content-Length to 16KiB, to avoid out-of-memory issues. + Ok(byte) => return Some(Ok((byte, (*content_length).min(MAX_CONTENT_LENGTH) + 1))), + Err(err) => return Some(Err(Error::IoError(err))), + } + } } + None } - } - let state = if chunked { - HttpStreamState::Chunked(true, 0, 0) - } else if let Some(length) = content_length { - HttpStreamState::ContentLength(length) - } else { - HttpStreamState::EndOnClose - }; + $($async)? fn $read_trailers<$($arg: $($argty +)*),*>( + bytes: &mut $stream_type, + headers: &mut BTreeMap, + mut max_headers_size: Option, + ) -> Result<(), Error> { + loop { + let trailer_line = maybe_await!($read_line(bytes, max_headers_size, Error::HeadersOverflow), $($await)?)?; + if let Some(ref mut max_headers_size) = max_headers_size { + *max_headers_size -= trailer_line.len() + 2; + } + if let Some((header, value)) = parse_header(trailer_line) { + headers.insert(header, value); + } else { + break; + } + } + Ok(()) + } - Ok(ResponseMetadata { - status_code, - reason_phrase, - headers, - state, - max_trailing_headers_size: max_headers_size, - }) -} + $($async)? fn $read_chunked<$($arg: $($argty +)*),*>( + bytes: &mut $stream_type, + headers: &mut BTreeMap, + expecting_more_chunks: &mut bool, + chunk_length: &mut usize, + content_length: &mut usize, + max_trailing_headers_size: Option, + ) -> Option<::Item> { + if !*expecting_more_chunks && *chunk_length == 0 { + return None; + } -#[cfg(feature = "std")] -fn read_line( - stream: &mut HttpStreamBytes, - max_len: Option, - overflow_error: Error, -) -> Result { - let mut bytes = Vec::with_capacity(32); - for byte in stream { - match byte { - Ok(byte) => { - if let Some(max_len) = max_len { - if bytes.len() >= max_len { - return Err(overflow_error); + if *chunk_length == 0 { + // Max length of the chunk length line is 1KB: not too long to + // take up much memory, long enough to tolerate some chunk + // extensions (which are ignored). + + // Get the size of the next chunk + let length_line = match maybe_await!($read_line(bytes, Some(1024), Error::MalformedChunkLength), $($await)?) { + Ok(line) => line, + Err(err) => return Some(Err(err)), + }; + + // Note: the trim() and check for empty lines shouldn't be + // needed according to the RFC, but we might as well, it's a + // small change and it fixes a few servers. + let incoming_length = if length_line.is_empty() { + 0 + } else { + let length = if let Some(i) = length_line.find(';') { + length_line[..i].trim() + } else { + length_line.trim() + }; + match usize::from_str_radix(length, 16) { + Ok(length) => length, + Err(_) => return Some(Err(Error::MalformedChunkLength)), } + }; + + if incoming_length == 0 { + if let Err(err) = maybe_await!($read_trailers(bytes, headers, max_trailing_headers_size), $($await)?) { + return Some(Err(err)); + } + + *expecting_more_chunks = false; + headers.insert("content-length".to_string(), (*content_length).to_string()); + headers.remove("transfer-encoding"); + return None; } - if byte == b'\n' { - if let Some(b'\r') = bytes.last() { - bytes.pop(); + *chunk_length = incoming_length; + *content_length += incoming_length; + } + + if *chunk_length > 0 { + *chunk_length -= 1; + if let Some(byte) = maybe_await!(bytes.next(), $($await)?) { + match byte { + Ok(byte) => { + // If we're at the end of the chunk... + if *chunk_length == 0 { + //...read the trailing \r\n of the chunk, and + // possibly return an error instead. + + // TODO: Maybe this could be written in a way + // that doesn't discard the last ok byte if + // the \r\n reading fails? + if let Err(err) = maybe_await!($read_line(bytes, Some(2), Error::MalformedChunkEnd), $($await)?) { + return Some(Err(err)); + } + } + + return Some(Ok((byte, (*chunk_length).min(MAX_CONTENT_LENGTH) + 1))); + } + Err(err) => return Some(Err(Error::IoError(err))), } + } + } + + None + } + + #[cfg(feature = "std")] + $($async)? fn $read_metadata<$($arg: $($argty +)*),*>( + stream: &mut $stream_type, + mut max_headers_size: Option, + max_status_line_len: Option, + ) -> Result { + let line = maybe_await!($read_line(stream, max_status_line_len, Error::StatusLineOverflow), $($await)?)?; + let (status_code, reason_phrase) = parse_status_line(&line); + + let mut headers = BTreeMap::new(); + loop { + let line = maybe_await!($read_line(stream, max_headers_size, Error::HeadersOverflow), $($await)?)?; + if line.is_empty() { + // Body starts here break; - } else { - bytes.push(byte); + } + if let Some(ref mut max_headers_size) = max_headers_size { + *max_headers_size -= line.len() + 2; + } + if let Some(header) = parse_header(line) { + headers.insert(header.0, header.1); } } - Err(err) => return Err(Error::IoError(err)), + + let mut chunked = false; + let mut content_length = None; + for (header, value) in &headers { + // Handle the Transfer-Encoding header + if header.to_lowercase().trim() == "transfer-encoding" + && value.to_lowercase().trim() == "chunked" + { + chunked = true; + } + + // Handle the Content-Length header + if header.to_lowercase().trim() == "content-length" { + match str::parse::(value.trim()) { + Ok(length) => content_length = Some(length), + Err(_) => return Err(Error::MalformedContentLength), + } + } + } + + let state = if chunked { + HttpStreamState::Chunked(true, 0, 0) + } else if let Some(length) = content_length { + HttpStreamState::ContentLength(length) + } else { + HttpStreamState::EndOnClose + }; + + Ok(ResponseMetadata { + status_code, + reason_phrase, + headers, + state, + max_trailing_headers_size: max_headers_size, + }) + } + + #[cfg(feature = "std")] + $($async)? fn $read_line<$($arg: $($argty +)*),*>( + stream: &mut $stream_type, + max_len: Option, + overflow_error: Error, + ) -> Result { + let mut bytes = Vec::with_capacity(32); + while let Some(byte) = maybe_await!(stream.next(), $($await)?) { + match byte { + Ok(byte) => { + if let Some(max_len) = max_len { + if bytes.len() >= max_len { + return Err(overflow_error); + } + } + if byte == b'\n' { + if let Some(b'\r') = bytes.last() { + bytes.pop(); + } + break; + } else { + bytes.push(byte); + } + } + Err(err) => return Err(Error::IoError(err)), + } + } + String::from_utf8(bytes).map_err(|_error| Error::InvalidUtf8InResponse) } } - String::from_utf8(bytes).map_err(|_error| Error::InvalidUtf8InResponse) } +#[cfg(feature = "std")] +define_read_methods!((read_until_closed, read_with_content_length, read_trailers, read_chunked, read_metadata, read_line)<>, HttpStreamBytes); +#[cfg(feature = "async")] +define_read_methods!((read_until_closed_async, read_with_content_length_async, read_trailers_async, read_chunked_async, read_metadata_async, read_line_async), R, async, await); + #[cfg(feature = "std")] fn parse_status_line(line: &str) -> (i32, String) { // sample status line format diff --git a/bitreq/tests/main.rs b/bitreq/tests/main.rs index 88c40fbc..d32c1d38 100644 --- a/bitreq/tests/main.rs +++ b/bitreq/tests/main.rs @@ -7,16 +7,16 @@ use std::io; use self::setup::*; -#[test] +#[tokio::test] #[cfg(feature = "rustls")] -fn test_https() { +async fn test_https() { // TODO: Implement this locally. - assert_eq!(get_status_code(bitreq::get("https://example.com").send()), 200,); + assert_eq!(get_status_code(bitreq::get("https://example.com")).await, 200); } -#[test] +#[tokio::test] #[cfg(feature = "json-using-serde")] -fn test_json_using_serde() { +async fn test_json_using_serde() { const JSON_SRC: &str = r#"{ "str": "Json test", "num": 42 @@ -24,158 +24,161 @@ fn test_json_using_serde() { setup(); let original_json: serde_json::Value = serde_json::from_str(JSON_SRC).unwrap(); - let response = bitreq::post(url("/echo")).with_json(&original_json).unwrap().send().unwrap(); + let response = + make_request(bitreq::post(url("/echo")).with_json(&original_json).unwrap()).await; let actual_json: serde_json::Value = response.json().unwrap(); assert_eq!(actual_json, original_json); } -#[test] -fn test_timeout_too_low() { +#[tokio::test] +async fn test_timeout_too_low() { setup(); - let result = bitreq::get(url("/slow_a")).with_body("Q".to_string()).with_timeout(1).send(); + let request = bitreq::get(url("/slow_a")).with_body("Q".to_string()).with_timeout(1); + let result = maybe_make_request(request).await; assert!(result.is_err()); } -#[test] -fn test_timeout_high_enough() { +#[tokio::test] +async fn test_timeout_high_enough() { setup(); let body = - get_body(bitreq::get(url("/slow_a")).with_body("Q".to_string()).with_timeout(3).send()); + get_body(bitreq::get(url("/slow_a")).with_body("Q".to_string()).with_timeout(3)).await; assert_eq!(body, "j: Q"); } -#[test] -fn test_headers() { +#[tokio::test] +async fn test_headers() { setup(); - let body = get_body(bitreq::get(url("/header_pong")).with_header("Ping", "Qwerty").send()); + let body = get_body(bitreq::get(url("/header_pong")).with_header("Ping", "Qwerty")).await; assert_eq!("Qwerty", body); } -#[test] -fn test_custom_method() { +#[tokio::test] +async fn test_custom_method() { use bitreq::Method; setup(); - let body = get_body( - bitreq::Request::new(Method::Custom("GET".to_string()), url("/a")).with_body("Q").send(), - ); + let body = + get_body(bitreq::Request::new(Method::Custom("GET".to_string()), url("/a")).with_body("Q")) + .await; assert_eq!("j: Q", body); } -#[test] -fn test_get() { +#[tokio::test] +async fn test_get() { setup(); - let body = get_body(bitreq::get(url("/a")).with_body("Q").send()); + let body = get_body(bitreq::get(url("/a")).with_body("Q")).await; assert_eq!(body, "j: Q"); } -#[test] -fn test_redirect_get() { +#[tokio::test] +async fn test_redirect_get() { setup(); - let body = get_body(bitreq::get(url("/redirect")).with_body("Q").send()); + let body = get_body(bitreq::get(url("/redirect")).with_body("Q")).await; assert_eq!(body, "j: Q"); } -#[test] -fn test_redirect_post() { +#[tokio::test] +async fn test_redirect_post() { setup(); // POSTing to /redirect should return a 303, which means we should // make a GET request to the given location. This test relies on // the fact that the test server only responds to GET requests on // the /a path. - let body = get_body(bitreq::post(url("/redirect")).with_body("Q").send()); + let body = get_body(bitreq::post(url("/redirect")).with_body("Q")).await; assert_eq!(body, "j: Q"); } -#[test] -fn test_redirect_with_fragment() { +#[tokio::test] +async fn test_redirect_with_fragment() { setup(); let original_url = url("/redirect#foo"); - let res = bitreq::get(original_url).send().unwrap(); + let res = make_request(bitreq::get(original_url)).await; // Fragment should stay the same, otherwise redirected assert_eq!(res.url.as_str(), url("/a#foo")); } -#[test] -fn test_redirect_with_overridden_fragment() { +#[tokio::test] +async fn test_redirect_with_overridden_fragment() { setup(); let original_url = url("/redirect-baz#foo"); - let res = bitreq::get(original_url).send().unwrap(); + let res = make_request(bitreq::get(original_url)).await; // This redirect should provide its own fragment, overriding the initial one assert_eq!(res.url.as_str(), url("/a#baz")); } -#[test] -fn test_infinite_redirect() { +#[tokio::test] +async fn test_infinite_redirect() { setup(); - let body = bitreq::get(url("/infiniteredirect")).send(); + let body = maybe_make_request(bitreq::get(url("/infiniteredirect"))).await; assert!(body.is_err()); } -#[test] -fn test_relative_redirect_get() { +#[tokio::test] +async fn test_relative_redirect_get() { setup(); - let body = get_body(bitreq::get(url("/relativeredirect")).with_body("Q").send()); + let body = get_body(bitreq::get(url("/relativeredirect")).with_body("Q")).await; assert_eq!(body, "j: Q"); } -#[test] -fn test_head() { +#[tokio::test] +async fn test_head() { setup(); - assert_eq!(get_status_code(bitreq::head(url("/b")).send()), 418); + assert_eq!(get_status_code(bitreq::head(url("/b"))).await, 418); } -#[test] -fn test_post() { +#[tokio::test] +async fn test_post() { setup(); - let body = get_body(bitreq::post(url("/c")).with_body("E").send()); + let body = get_body(bitreq::post(url("/c")).with_body("E")).await; assert_eq!(body, "l: E"); } -#[test] -fn test_put() { +#[tokio::test] +async fn test_put() { setup(); - let body = get_body(bitreq::put(url("/d")).with_body("R").send()); + let body = get_body(bitreq::put(url("/d")).with_body("R")).await; assert_eq!(body, "m: R"); } -#[test] -fn test_delete() { +#[tokio::test] +async fn test_delete() { setup(); - assert_eq!(get_body(bitreq::delete(url("/e")).send()), "n: "); + assert_eq!(get_body(bitreq::delete(url("/e"))).await, "n: "); } -#[test] -fn test_trace() { +#[tokio::test] +async fn test_trace() { setup(); - assert_eq!(get_body(bitreq::trace(url("/f")).send()), "o: "); + assert_eq!(get_body(bitreq::trace(url("/f"))).await, "o: "); } -#[test] -fn test_options() { +#[tokio::test] +async fn test_options() { setup(); - let body = get_body(bitreq::options(url("/g")).with_body("U").send()); + let body = get_body(bitreq::options(url("/g")).with_body("U")).await; assert_eq!(body, "p: U"); } -#[test] -fn test_connect() { +#[tokio::test] +async fn test_connect() { setup(); - let body = get_body(bitreq::connect(url("/h")).with_body("I").send()); + let body = get_body(bitreq::connect(url("/h")).with_body("I")).await; assert_eq!(body, "q: I"); } -#[test] -fn test_patch() { +#[tokio::test] +async fn test_patch() { setup(); - let body = get_body(bitreq::patch(url("/i")).with_body("O").send()); + let body = get_body(bitreq::patch(url("/i")).with_body("O")).await; assert_eq!(body, "r: O"); } -#[test] -fn tcp_connect_timeout() { +#[tokio::test] +async fn tcp_connect_timeout() { let _listener = std::net::TcpListener::bind("127.0.0.1:32162").unwrap(); - let resp = - bitreq::Request::new(bitreq::Method::Get, "http://127.0.0.1:32162").with_timeout(1).send(); + let request = + bitreq::Request::new(bitreq::Method::Get, "http://127.0.0.1:32162").with_timeout(1); + let resp = maybe_make_request(request).await; assert!(resp.is_err()); if let Some(bitreq::Error::IoError(err)) = resp.err() { assert_eq!(err.kind(), io::ErrorKind::TimedOut); @@ -184,41 +187,41 @@ fn tcp_connect_timeout() { } } -#[test] -fn test_header_cap() { +#[tokio::test] +async fn test_header_cap() { setup(); - let body = bitreq::get(url("/long_header")).with_max_headers_size(999).send(); - assert!(body.is_err()); - assert!(matches!(body.err(), Some(bitreq::Error::HeadersOverflow))); + let res = maybe_make_request(bitreq::get(url("/long_header")).with_max_headers_size(999)).await; + assert!(res.is_err()); + assert!(matches!(res.err(), Some(bitreq::Error::HeadersOverflow))); - let body = bitreq::get(url("/long_header")).with_max_headers_size(1500).send(); - assert!(body.is_ok()); + make_request(bitreq::get(url("/long_header")).with_max_headers_size(1500)).await; } -#[test] -fn test_status_line_cap() { +#[tokio::test] +async fn test_status_line_cap() { setup(); let expected_status_line = "HTTP/1.1 203 Non-Authoritative Information"; - let body = bitreq::get(url("/long_status_line")) - .with_max_status_line_length(expected_status_line.len() + 1) - .send(); - assert!(body.is_err()); - assert!(matches!(body.err(), Some(bitreq::Error::StatusLineOverflow))); + let request = bitreq::get(url("/long_status_line")) + .with_max_status_line_length(expected_status_line.len() + 1); + let resp = maybe_make_request(request).await; + assert!(resp.is_err()); + assert!(matches!(resp.err(), Some(bitreq::Error::StatusLineOverflow))); - let body = bitreq::get(url("/long_status_line")) - .with_max_status_line_length(expected_status_line.len() + 2) - .send(); - assert!(body.is_ok()); + let request = bitreq::get(url("/long_status_line")) + .with_max_status_line_length(expected_status_line.len() + 2); + make_request(request).await; } -#[test] -fn test_massive_content_length() { +#[tokio::test] +async fn test_massive_content_length() { setup(); + #[cfg(feature = "async")] + tokio::spawn(bitreq::get(url("/massive_content_length")).send_async()); std::thread::spawn(|| { // If bitreq trusts Content-Length, this should crash pretty much straight away. let _ = bitreq::get(url("/massive_content_length")).send(); }); - std::thread::sleep(std::time::Duration::from_millis(500)); + tokio::time::sleep(std::time::Duration::from_millis(500)).await; // If it were to crash, it would have at this point. Pass! } diff --git a/bitreq/tests/setup.rs b/bitreq/tests/setup.rs index efde7f97..d0c9575c 100644 --- a/bitreq/tests/setup.rs +++ b/bitreq/tests/setup.rs @@ -1,5 +1,8 @@ +#![cfg(feature = "std")] + extern crate bitreq; extern crate tiny_http; +use std::io::Read; use std::str::FromStr; use std::sync::{Arc, Once}; use std::thread; @@ -12,7 +15,7 @@ static INIT: Once = Once::new(); pub fn setup() { INIT.call_once(|| { let server = Arc::new(Server::http("localhost:35562").unwrap()); - for _ in 0..4 { + for _ in 0..8 { let server = server.clone(); thread::spawn(move || loop { @@ -174,28 +177,68 @@ pub fn setup() { pub fn url(req: &str) -> String { format!("http://localhost:35562{}", req) } -pub fn get_body(request: Result) -> String { - match request { - Ok(response) => match response.as_str() { - Ok(str) => String::from(str), - Err(err) => { - println!("\n[ERROR]: {}\n", err); - String::new() - } - }, - Err(err) => { - println!("\n[ERROR]: {}\n", err); - String::new() +pub async fn maybe_make_request( + request: bitreq::Request, +) -> Result { + let response = request.clone().send(); + let lazy_response = request.clone().send_lazy(); + match (&response, lazy_response) { + (Ok(resp), Ok(mut lazy_resp)) => { + assert_eq!(lazy_resp.status_code, resp.status_code); + assert_eq!(lazy_resp.reason_phrase, resp.reason_phrase); + let mut lazy_bytes = Vec::new(); + lazy_resp.read_to_end(&mut lazy_bytes).unwrap(); + assert_eq!(lazy_bytes, resp.as_bytes()); } + (Err(e), Err(lazy_e)) => assert_eq!(format!("{e:?}"), format!("{lazy_e:?}")), + (res, lazy_res) => panic!("{res:?} != {}", lazy_res.is_err()), } -} -pub fn get_status_code(request: Result) -> i32 { - match request { - Ok(response) => response.status_code, - Err(err) => { - println!("\n[ERROR]: {}\n", err); - -1 + #[cfg(feature = "async")] + { + if let Ok(resp) = &response { + if resp.url.starts_with("https") && !cfg!(feature = "async-https") { + return response; + } + } else { + // Assume its not HTTPS or async-https is set + } + let async_response = request.clone().send_async().await; + let lazy_async_response = request.send_lazy_async().await; + match (&response, &async_response) { + (Ok(resp), Ok(async_resp)) => { + assert_eq!(async_resp.status_code, resp.status_code); + assert_eq!(async_resp.reason_phrase, resp.reason_phrase); + assert_eq!(async_resp.as_bytes(), resp.as_bytes()); + } + (Err(e), Err(async_e)) => assert_eq!(format!("{e:?}"), format!("{async_e:?}")), + (res, async_res) => panic!("{res:?} != {async_res:?}"), + } + match (&response, lazy_async_response) { + (Ok(resp), Ok(mut lazy_resp)) => { + assert_eq!(lazy_resp.status_code, resp.status_code); + assert_eq!(lazy_resp.reason_phrase, resp.reason_phrase); + let mut lazy_bytes = Vec::new(); + lazy_resp.read_to_end(&mut lazy_bytes).unwrap(); + assert_eq!(lazy_bytes, resp.as_bytes()); + } + (Err(e), Err(lazy_e)) => assert_eq!(format!("{e:?}"), format!("{lazy_e:?}")), + (res, lazy_res) => panic!("{res:?} != {}", lazy_res.is_err()), } } + response +} + +pub async fn make_request(request: bitreq::Request) -> bitreq::Response { + maybe_make_request(request).await.unwrap() +} + +pub async fn get_body(request: bitreq::Request) -> String { + let response = make_request(request).await; + String::from(response.as_str().unwrap()) +} + +pub async fn get_status_code(request: bitreq::Request) -> i32 { + let response = make_request(request).await; + response.status_code }