Skip to content

Commit 9607d28

Browse files
authored
Merge pull request #10305 from Turbo87/mono-test-utils
tests/util: Reduce monomorphization overhead for the `RequestHelper` code
2 parents 52f2f0c + 113dcf2 commit 9607d28

File tree

3 files changed

+59
-54
lines changed

3 files changed

+59
-54
lines changed

src/tests/util.rs

Lines changed: 28 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ use crate::tests::{
2424
CategoryListResponse, CategoryResponse, CrateList, CrateResponse, GoodCrate, OwnerResp,
2525
OwnersResponse, VersionResponse,
2626
};
27+
use std::future::Future;
2728

2829
use http::{Method, Request};
2930

@@ -33,6 +34,7 @@ use axum::body::{Body, Bytes};
3334
use axum::extract::connect_info::MockConnectInfo;
3435
use chrono::NaiveDateTime;
3536
use cookie::Cookie;
37+
use futures_util::FutureExt;
3638
use http::header;
3739
use secrecy::ExposeSecret;
3840
use serde_json::json;
@@ -91,23 +93,28 @@ pub trait RequestHelper {
9193
fn app(&self) -> &TestApp;
9294

9395
/// Run a request that is expected to succeed
94-
async fn run<T>(&self, request: Request<impl Into<Body>>) -> Response<T> {
96+
fn run<T>(&self, request: Request<impl Into<Body>>) -> impl Future<Output = Response<T>> {
9597
let app = self.app();
96-
let router = app.router().clone();
98+
let request = request.map(Into::into);
9799

98-
// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
99-
// extractor has something to extract.
100-
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
101-
let router = router.layer(MockConnectInfo(mocket_addr));
100+
// This inner function is used to avoid long compile times
101+
// due to monomorphization of the `run()` fn itself
102+
async fn inner(app: &TestApp, request: Request<Body>) -> axum::response::Response<Bytes> {
103+
let router = app.router().clone();
102104

103-
let request = request.map(Into::into);
104-
let axum_response = router.oneshot(request).await.unwrap();
105+
// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
106+
// extractor has something to extract.
107+
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
108+
let router = router.layer(MockConnectInfo(mocket_addr));
109+
110+
let axum_response = router.oneshot(request).await.unwrap();
105111

106-
let (parts, body) = axum_response.into_parts();
107-
let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
108-
let bytes_response = axum::response::Response::from_parts(parts, bytes);
112+
let (parts, body) = axum_response.into_parts();
113+
let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
114+
axum::response::Response::from_parts(parts, bytes)
115+
}
109116

110-
Response::new(bytes_response)
117+
inner(app, request).map(Response::new)
111118
}
112119

113120
/// Create a get request
@@ -134,26 +141,18 @@ pub trait RequestHelper {
134141

135142
/// Issue a PUT request
136143
async fn put<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
137-
let body = body.into();
138-
139-
let mut request = self.request_builder(Method::PUT, path);
140-
*request.body_mut() = body;
141-
if is_json_body(request.body()) {
142-
request.header(header::CONTENT_TYPE, "application/json");
143-
}
144+
let request = self
145+
.request_builder(Method::PUT, path)
146+
.with_body(body.into());
144147

145148
self.run(request).await
146149
}
147150

148151
/// Issue a PATCH request
149152
async fn patch<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
150-
let body = body.into();
151-
152-
let mut request = self.request_builder(Method::PATCH, path);
153-
*request.body_mut() = body;
154-
if is_json_body(request.body()) {
155-
request.header(header::CONTENT_TYPE, "application/json");
156-
}
153+
let request = self
154+
.request_builder(Method::PATCH, path)
155+
.with_body(body.into());
157156

158157
self.run(request).await
159158
}
@@ -166,13 +165,9 @@ pub trait RequestHelper {
166165

167166
/// Issue a DELETE request with a body... yes we do it, for crate owner removal
168167
async fn delete_with_body<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
169-
let body = body.into();
170-
171-
let mut request = self.request_builder(Method::DELETE, path);
172-
*request.body_mut() = body;
173-
if is_json_body(request.body()) {
174-
request.header(header::CONTENT_TYPE, "application/json");
175-
}
168+
let request = self
169+
.request_builder(Method::DELETE, path)
170+
.with_body(body.into());
176171

177172
self.run(request).await
178173
}
@@ -256,11 +251,6 @@ fn req(method: Method, path: &str) -> MockRequest {
256251
.unwrap()
257252
}
258253

259-
fn is_json_body(body: &Bytes) -> bool {
260-
(body.starts_with(b"{") && body.ends_with(b"}"))
261-
|| (body.starts_with(b"[") && body.ends_with(b"]"))
262-
}
263-
264254
/// A type that can generate unauthenticated requests
265255
pub struct MockAnonymousUser {
266256
app: TestApp,

src/tests/util/mock_request.rs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use axum::body::Bytes;
2-
use http::{header::IntoHeaderName, HeaderValue, Request};
2+
use http::{header, header::IntoHeaderName, HeaderValue, Request};
33

44
pub type MockRequest = Request<Bytes>;
55

66
pub trait MockRequestExt {
77
fn header<K: IntoHeaderName>(&mut self, name: K, value: &str);
8+
fn with_body(self, bytes: Bytes) -> Self;
89
}
910

1011
impl MockRequestExt for MockRequest {
@@ -15,6 +16,20 @@ impl MockRequestExt for MockRequest {
1516
self.headers_mut()
1617
.append(name, HeaderValue::from_str(value).unwrap());
1718
}
19+
20+
fn with_body(mut self, bytes: Bytes) -> Self {
21+
if is_json_body(&bytes) {
22+
self.header(header::CONTENT_TYPE, "application/json");
23+
}
24+
25+
*self.body_mut() = bytes;
26+
self
27+
}
28+
}
29+
30+
fn is_json_body(body: &Bytes) -> bool {
31+
(body.starts_with(b"{") && body.ends_with(b"}"))
32+
|| (body.starts_with(b"[") && body.ends_with(b"]"))
1833
}
1934

2035
#[cfg(test)]

src/tests/util/response.rs

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,6 @@ impl<T> Response<T> {
5151
assert_ok!(from_utf8(bytes)).to_string()
5252
}
5353

54-
pub fn status(&self) -> StatusCode {
55-
self.response.status()
56-
}
57-
5854
#[track_caller]
5955
pub fn assert_redirect_ends_with(&self, target: &str) -> &Self {
6056
let headers = self.response.headers();
@@ -105,21 +101,25 @@ fn json<T>(r: &hyper::Response<Bytes>) -> T
105101
where
106102
for<'de> T: serde::Deserialize<'de>,
107103
{
108-
let headers = r.headers();
104+
fn inner(r: &hyper::Response<Bytes>) -> &Bytes {
105+
let headers = r.headers();
109106

110-
assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json");
107+
assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json");
111108

112-
let content_length = assert_some!(
113-
r.headers().get(header::CONTENT_LENGTH),
114-
"Missing content-length header"
115-
);
116-
let content_length = assert_ok!(content_length.to_str());
117-
let content_length: usize = assert_ok!(content_length.parse());
109+
let content_length = assert_some!(
110+
r.headers().get(header::CONTENT_LENGTH),
111+
"Missing content-length header"
112+
);
113+
let content_length = assert_ok!(content_length.to_str());
114+
let content_length: usize = assert_ok!(content_length.parse());
118115

119-
let bytes = r.body();
120-
assert_that!(*bytes, len(eq(content_length)));
116+
let bytes = r.body();
117+
assert_that!(*bytes, len(eq(content_length)));
118+
119+
bytes
120+
}
121121

122-
match serde_json::from_slice(bytes) {
122+
match serde_json::from_slice(inner(r)) {
123123
Ok(t) => t,
124124
Err(e) => panic!("failed to decode: {e:?}"),
125125
}

0 commit comments

Comments
 (0)