Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 28 additions & 38 deletions src/tests/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::tests::{
CategoryListResponse, CategoryResponse, CrateList, CrateResponse, GoodCrate, OwnerResp,
OwnersResponse, VersionResponse,
};
use std::future::Future;

use http::{Method, Request};

Expand All @@ -33,6 +34,7 @@ use axum::body::{Body, Bytes};
use axum::extract::connect_info::MockConnectInfo;
use chrono::NaiveDateTime;
use cookie::Cookie;
use futures_util::FutureExt;
use http::header;
use secrecy::ExposeSecret;
use serde_json::json;
Expand Down Expand Up @@ -91,23 +93,28 @@ pub trait RequestHelper {
fn app(&self) -> &TestApp;

/// Run a request that is expected to succeed
async fn run<T>(&self, request: Request<impl Into<Body>>) -> Response<T> {
fn run<T>(&self, request: Request<impl Into<Body>>) -> impl Future<Output = Response<T>> {
let app = self.app();
let router = app.router().clone();
let request = request.map(Into::into);

// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
// extractor has something to extract.
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
let router = router.layer(MockConnectInfo(mocket_addr));
// This inner function is used to avoid long compile times
// due to monomorphization of the `run()` fn itself
async fn inner(app: &TestApp, request: Request<Body>) -> axum::response::Response<Bytes> {
let router = app.router().clone();

let request = request.map(Into::into);
let axum_response = router.oneshot(request).await.unwrap();
// Add a mock `SocketAddr` to the requests so that the `ConnectInfo`
// extractor has something to extract.
let mocket_addr = SocketAddr::from(([127, 0, 0, 1], 52381));
let router = router.layer(MockConnectInfo(mocket_addr));

let axum_response = router.oneshot(request).await.unwrap();

let (parts, body) = axum_response.into_parts();
let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let bytes_response = axum::response::Response::from_parts(parts, bytes);
let (parts, body) = axum_response.into_parts();
let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
axum::response::Response::from_parts(parts, bytes)
}

Response::new(bytes_response)
inner(app, request).map(Response::new)
}

/// Create a get request
Expand All @@ -134,26 +141,18 @@ pub trait RequestHelper {

/// Issue a PUT request
async fn put<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
let body = body.into();

let mut request = self.request_builder(Method::PUT, path);
*request.body_mut() = body;
if is_json_body(request.body()) {
request.header(header::CONTENT_TYPE, "application/json");
}
let request = self
.request_builder(Method::PUT, path)
.with_body(body.into());

self.run(request).await
}

/// Issue a PATCH request
async fn patch<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
let body = body.into();

let mut request = self.request_builder(Method::PATCH, path);
*request.body_mut() = body;
if is_json_body(request.body()) {
request.header(header::CONTENT_TYPE, "application/json");
}
let request = self
.request_builder(Method::PATCH, path)
.with_body(body.into());

self.run(request).await
}
Expand All @@ -166,13 +165,9 @@ pub trait RequestHelper {

/// Issue a DELETE request with a body... yes we do it, for crate owner removal
async fn delete_with_body<T>(&self, path: &str, body: impl Into<Bytes>) -> Response<T> {
let body = body.into();

let mut request = self.request_builder(Method::DELETE, path);
*request.body_mut() = body;
if is_json_body(request.body()) {
request.header(header::CONTENT_TYPE, "application/json");
}
let request = self
.request_builder(Method::DELETE, path)
.with_body(body.into());

self.run(request).await
}
Expand Down Expand Up @@ -256,11 +251,6 @@ fn req(method: Method, path: &str) -> MockRequest {
.unwrap()
}

fn is_json_body(body: &Bytes) -> bool {
(body.starts_with(b"{") && body.ends_with(b"}"))
|| (body.starts_with(b"[") && body.ends_with(b"]"))
}

/// A type that can generate unauthenticated requests
pub struct MockAnonymousUser {
app: TestApp,
Expand Down
17 changes: 16 additions & 1 deletion src/tests/util/mock_request.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use axum::body::Bytes;
use http::{header::IntoHeaderName, HeaderValue, Request};
use http::{header, header::IntoHeaderName, HeaderValue, Request};

pub type MockRequest = Request<Bytes>;

pub trait MockRequestExt {
fn header<K: IntoHeaderName>(&mut self, name: K, value: &str);
fn with_body(self, bytes: Bytes) -> Self;
}

impl MockRequestExt for MockRequest {
Expand All @@ -15,6 +16,20 @@ impl MockRequestExt for MockRequest {
self.headers_mut()
.append(name, HeaderValue::from_str(value).unwrap());
}

fn with_body(mut self, bytes: Bytes) -> Self {
if is_json_body(&bytes) {
self.header(header::CONTENT_TYPE, "application/json");
}

*self.body_mut() = bytes;
self
}
}

fn is_json_body(body: &Bytes) -> bool {
(body.starts_with(b"{") && body.ends_with(b"}"))
|| (body.starts_with(b"[") && body.ends_with(b"]"))
}

#[cfg(test)]
Expand Down
30 changes: 15 additions & 15 deletions src/tests/util/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ impl<T> Response<T> {
assert_ok!(from_utf8(bytes)).to_string()
}

pub fn status(&self) -> StatusCode {
self.response.status()
}

#[track_caller]
pub fn assert_redirect_ends_with(&self, target: &str) -> &Self {
let headers = self.response.headers();
Expand Down Expand Up @@ -105,21 +101,25 @@ fn json<T>(r: &hyper::Response<Bytes>) -> T
where
for<'de> T: serde::Deserialize<'de>,
{
let headers = r.headers();
fn inner(r: &hyper::Response<Bytes>) -> &Bytes {
let headers = r.headers();

assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json");
assert_some_eq!(headers.get(header::CONTENT_TYPE), "application/json");

let content_length = assert_some!(
r.headers().get(header::CONTENT_LENGTH),
"Missing content-length header"
);
let content_length = assert_ok!(content_length.to_str());
let content_length: usize = assert_ok!(content_length.parse());
let content_length = assert_some!(
r.headers().get(header::CONTENT_LENGTH),
"Missing content-length header"
);
let content_length = assert_ok!(content_length.to_str());
let content_length: usize = assert_ok!(content_length.parse());

let bytes = r.body();
assert_that!(*bytes, len(eq(content_length)));
let bytes = r.body();
assert_that!(*bytes, len(eq(content_length)));

bytes
}

match serde_json::from_slice(bytes) {
match serde_json::from_slice(inner(r)) {
Ok(t) => t,
Err(e) => panic!("failed to decode: {e:?}"),
}
Expand Down