diff --git a/src/tests/util.rs b/src/tests/util.rs index 8d89762ac9b..7491e9f8578 100644 --- a/src/tests/util.rs +++ b/src/tests/util.rs @@ -24,6 +24,7 @@ use crate::tests::{ CategoryListResponse, CategoryResponse, CrateList, CrateResponse, GoodCrate, OwnerResp, OwnersResponse, VersionResponse, }; +use std::future::Future; use http::{Method, Request}; @@ -33,6 +34,7 @@ use axum::body::{Body, Bytes}; use axum::extract::connect_info::MockConnectInfo; use chrono::NaiveDateTime; use cookie::Cookie; +use futures_util::FutureExt; use http::header; use secrecy::ExposeSecret; use serde_json::json; @@ -91,23 +93,28 @@ pub trait RequestHelper { fn app(&self) -> &TestApp; /// Run a request that is expected to succeed - async fn run(&self, request: Request>) -> Response { + fn run(&self, request: Request>) -> impl Future> { let app = self.app(); - let router = app.router().clone(); + let request = request.map(Into::into); - // Add a mock `SocketAddr` to the requests so that the `ConnectInfo` - // extractor has something to extract. - let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381)); - let router = router.layer(MockConnectInfo(mocket_addr)); + // This inner function is used to avoid long compile times + // due to monomorphization of the `run()` fn itself + async fn inner(app: &TestApp, request: Request) -> axum::response::Response { + let router = app.router().clone(); - let request = request.map(Into::into); - let axum_response = router.oneshot(request).await.unwrap(); + // Add a mock `SocketAddr` to the requests so that the `ConnectInfo` + // extractor has something to extract. + let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381)); + let router = router.layer(MockConnectInfo(mocket_addr)); + + let axum_response = router.oneshot(request).await.unwrap(); - let (parts, body) = axum_response.into_parts(); - let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); - let bytes_response = axum::response::Response::from_parts(parts, bytes); + let (parts, body) = axum_response.into_parts(); + let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap(); + axum::response::Response::from_parts(parts, bytes) + } - Response::new(bytes_response) + inner(app, request).map(Response::new) } /// Create a get request @@ -134,26 +141,18 @@ pub trait RequestHelper { /// Issue a PUT request async fn put(&self, path: &str, body: impl Into) -> Response { - let body = body.into(); - - let mut request = self.request_builder(Method::PUT, path); - *request.body_mut() = body; - if is_json_body(request.body()) { - request.header(header::CONTENT_TYPE, "application/json"); - } + let request = self + .request_builder(Method::PUT, path) + .with_body(body.into()); self.run(request).await } /// Issue a PATCH request async fn patch(&self, path: &str, body: impl Into) -> Response { - let body = body.into(); - - let mut request = self.request_builder(Method::PATCH, path); - *request.body_mut() = body; - if is_json_body(request.body()) { - request.header(header::CONTENT_TYPE, "application/json"); - } + let request = self + .request_builder(Method::PATCH, path) + .with_body(body.into()); self.run(request).await } @@ -166,13 +165,9 @@ pub trait RequestHelper { /// Issue a DELETE request with a body... yes we do it, for crate owner removal async fn delete_with_body(&self, path: &str, body: impl Into) -> Response { - let body = body.into(); - - let mut request = self.request_builder(Method::DELETE, path); - *request.body_mut() = body; - if is_json_body(request.body()) { - request.header(header::CONTENT_TYPE, "application/json"); - } + let request = self + .request_builder(Method::DELETE, path) + .with_body(body.into()); self.run(request).await } @@ -256,11 +251,6 @@ fn req(method: Method, path: &str) -> MockRequest { .unwrap() } -fn is_json_body(body: &Bytes) -> bool { - (body.starts_with(b"{") && body.ends_with(b"}")) - || (body.starts_with(b"[") && body.ends_with(b"]")) -} - /// A type that can generate unauthenticated requests pub struct MockAnonymousUser { app: TestApp, diff --git a/src/tests/util/mock_request.rs b/src/tests/util/mock_request.rs index 5e0ebf13dc5..077dd086999 100644 --- a/src/tests/util/mock_request.rs +++ b/src/tests/util/mock_request.rs @@ -1,10 +1,11 @@ use axum::body::Bytes; -use http::{header::IntoHeaderName, HeaderValue, Request}; +use http::{header, header::IntoHeaderName, HeaderValue, Request}; pub type MockRequest = Request; pub trait MockRequestExt { fn header(&mut self, name: K, value: &str); + fn with_body(self, bytes: Bytes) -> Self; } impl MockRequestExt for MockRequest { @@ -15,6 +16,20 @@ impl MockRequestExt for MockRequest { self.headers_mut() .append(name, HeaderValue::from_str(value).unwrap()); } + + fn with_body(mut self, bytes: Bytes) -> Self { + if is_json_body(&bytes) { + self.header(header::CONTENT_TYPE, "application/json"); + } + + *self.body_mut() = bytes; + self + } +} + +fn is_json_body(body: &Bytes) -> bool { + (body.starts_with(b"{") && body.ends_with(b"}")) + || (body.starts_with(b"[") && body.ends_with(b"]")) } #[cfg(test)] diff --git a/src/tests/util/response.rs b/src/tests/util/response.rs index 42525d5da13..6dbaea9626e 100644 --- a/src/tests/util/response.rs +++ b/src/tests/util/response.rs @@ -51,10 +51,6 @@ impl Response { assert_ok!(from_utf8(bytes)).to_string() } - pub fn status(&self) -> StatusCode { - self.response.status() - } - #[track_caller] pub fn assert_redirect_ends_with(&self, target: &str) -> &Self { let headers = self.response.headers(); @@ -105,21 +101,25 @@ fn json(r: &hyper::Response) -> T where for<'de> T: serde::Deserialize<'de>, { - let headers = r.headers(); + fn inner(r: &hyper::Response) -> &Bytes { + let headers = r.headers(); - assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json"); + assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json"); - let content_length = assert_some!( - r.headers().get(header::CONTENT_LENGTH), - "Missing content-length header" - ); - let content_length = assert_ok!(content_length.to_str()); - let content_length: usize = assert_ok!(content_length.parse()); + let content_length = assert_some!( + r.headers().get(header::CONTENT_LENGTH), + "Missing content-length header" + ); + let content_length = assert_ok!(content_length.to_str()); + let content_length: usize = assert_ok!(content_length.parse()); - let bytes = r.body(); - assert_that!(*bytes, len(eq(content_length))); + let bytes = r.body(); + assert_that!(*bytes, len(eq(content_length))); + + bytes + } - match serde_json::from_slice(bytes) { + match serde_json::from_slice(inner(r)) { Ok(t) => t, Err(e) => panic!("failed to decode: {e:?}"), }