@@ -12,9 +12,10 @@ use crate::{
1212 ServiceMetrics ,
1313} ;
1414use anyhow:: Context as _;
15- use axum:: async_trait;
15+ use axum:: { async_trait, body :: Body , http :: Request , response :: Response as AxumResponse , Router } ;
1616use fn_error_context:: context;
1717use futures_util:: { stream:: TryStreamExt , FutureExt } ;
18+ use http_body_util:: BodyExt ; // for `collect`
1819use once_cell:: sync:: OnceCell ;
1920use reqwest:: {
2021 blocking:: { Client , ClientBuilder , RequestBuilder , Response } ,
@@ -27,6 +28,7 @@ use std::{
2728} ;
2829use tokio:: runtime:: { Builder , Runtime } ;
2930use tokio:: sync:: oneshot:: Sender ;
31+ use tower:: ServiceExt ;
3032use tracing:: { debug, error, instrument, trace} ;
3133
3234#[ track_caller]
@@ -126,7 +128,6 @@ pub(crate) fn assert_success(path: &str, web: &TestFrontend) -> Result<()> {
126128 assert ! ( status. is_success( ) , "failed to GET {path}: {status}" ) ;
127129 Ok ( ( ) )
128130}
129-
130131/// Make sure that a URL returns a status code between 200-299,
131132/// also check the cache-control headers.
132133pub ( crate ) fn assert_success_cached (
@@ -259,6 +260,96 @@ pub(crate) fn assert_redirect_cached(
259260 Ok ( redirect_response)
260261}
261262
263+ pub ( crate ) trait AxumResponseTestExt {
264+ async fn text ( self ) -> String ;
265+ }
266+
267+ impl AxumResponseTestExt for axum:: response:: Response {
268+ async fn text ( self ) -> String {
269+ String :: from_utf8_lossy ( & self . into_body ( ) . collect ( ) . await . unwrap ( ) . to_bytes ( ) ) . to_string ( )
270+ }
271+ }
272+
273+ pub ( crate ) trait AxumRouterTestExt {
274+ async fn assert_success ( & self , path : & str ) -> Result < ( ) > ;
275+ async fn get ( & self , path : & str ) -> Result < AxumResponse > ;
276+ async fn assert_redirect_common (
277+ & self ,
278+ path : & str ,
279+ expected_target : & str ,
280+ ) -> Result < AxumResponse > ;
281+ async fn assert_redirect ( & self , path : & str , expected_target : & str ) -> Result < AxumResponse > ;
282+ }
283+
284+ impl AxumRouterTestExt for axum:: Router {
285+ /// Make sure that a URL returns a status code between 200-299
286+ async fn assert_success ( & self , path : & str ) -> Result < ( ) > {
287+ let response = self
288+ . clone ( )
289+ . oneshot ( Request :: builder ( ) . uri ( path) . body ( Body :: empty ( ) ) . unwrap ( ) )
290+ . await ?;
291+
292+ let status = response. status ( ) ;
293+ assert ! ( status. is_success( ) , "failed to GET {path}: {status}" ) ;
294+ Ok ( ( ) )
295+ }
296+ /// simple `get` method
297+ async fn get ( & self , path : & str ) -> Result < AxumResponse > {
298+ Ok ( self
299+ . clone ( )
300+ . oneshot ( Request :: builder ( ) . uri ( path) . body ( Body :: empty ( ) ) . unwrap ( ) )
301+ . await ?)
302+ }
303+
304+ async fn assert_redirect_common (
305+ & self ,
306+ path : & str ,
307+ expected_target : & str ,
308+ ) -> Result < AxumResponse > {
309+ let response = self . get ( path) . await ?;
310+ let status = response. status ( ) ;
311+ if !status. is_redirection ( ) {
312+ anyhow:: bail!( "non-redirect from GET {path}: {status}" ) ;
313+ }
314+
315+ let redirect_target = response
316+ . headers ( )
317+ . get ( "Location" )
318+ . context ( "missing 'Location' header" ) ?
319+ . to_str ( )
320+ . context ( "non-ASCII redirect" ) ?;
321+
322+ // FIXME: not sure we need this
323+ // if !expected_target.starts_with("http") {
324+ // // TODO: Should be able to use Url::make_relative,
325+ // // but https://github.com/servo/rust-url/issues/766
326+ // let base = format!("http://{}", web.server_addr());
327+ // redirect_target = redirect_target
328+ // .strip_prefix(&base)
329+ // .unwrap_or(redirect_target);
330+ // }
331+
332+ if redirect_target != expected_target {
333+ anyhow:: bail!( "got redirect to {redirect_target}" ) ;
334+ }
335+
336+ Ok ( response)
337+ }
338+
339+ #[ context( "expected redirect from {path} to {expected_target}" ) ]
340+ async fn assert_redirect ( & self , path : & str , expected_target : & str ) -> Result < AxumResponse > {
341+ let redirect_response = self . assert_redirect_common ( path, expected_target) . await ?;
342+
343+ let response = self . get ( expected_target) . await ?;
344+ let status = response. status ( ) ;
345+ if !status. is_success ( ) {
346+ anyhow:: bail!( "failed to GET {expected_target}: {status}" ) ;
347+ }
348+
349+ Ok ( redirect_response)
350+ }
351+ }
352+
262353pub ( crate ) struct TestEnvironment {
263354 build_queue : OnceCell < Arc < BuildQueue > > ,
264355 async_build_queue : tokio:: sync:: OnceCell < Arc < AsyncBuildQueue > > ,
@@ -534,6 +625,13 @@ impl TestEnvironment {
534625 self . runtime ( ) . block_on ( self . async_fake_release ( ) )
535626 }
536627
628+ pub ( crate ) async fn web_app ( & self ) -> Router {
629+ let template_data = Arc :: new ( TemplateData :: new ( 1 ) . unwrap ( ) ) ;
630+ build_axum_app ( self , template_data)
631+ . await
632+ . expect ( "could not build axum app" )
633+ }
634+
537635 pub ( crate ) async fn async_fake_release ( & self ) -> fakes:: FakeRelease {
538636 fakes:: FakeRelease :: new (
539637 self . async_db ( ) . await ,
@@ -569,6 +667,10 @@ impl Context for TestEnvironment {
569667 Ok ( TestEnvironment :: cdn ( self ) . await )
570668 }
571669
670+ async fn async_pool ( & self ) -> Result < Pool > {
671+ Ok ( self . async_db ( ) . await . pool ( ) )
672+ }
673+
572674 fn pool ( & self ) -> Result < Pool > {
573675 Ok ( self . db ( ) . pool ( ) )
574676 }
@@ -734,10 +836,12 @@ impl TestFrontend {
734836 let ( tx, rx) = tokio:: sync:: oneshot:: channel :: < ( ) > ( ) ;
735837
736838 debug ! ( "building axum app" ) ;
737- let axum_app = build_axum_app ( context, template_data) . expect ( "could not build axum app" ) ;
839+ let runtime = context. runtime ( ) . unwrap ( ) ;
840+ let axum_app = runtime
841+ . block_on ( build_axum_app ( context, template_data) )
842+ . expect ( "could not build axum app" ) ;
738843
739844 let handle = thread:: spawn ( {
740- let runtime = context. runtime ( ) . unwrap ( ) ;
741845 move || {
742846 runtime. block_on ( async {
743847 axum:: serve ( axum_listener, axum_app. into_make_service ( ) )
0 commit comments