11//! The htsget authorization middleware.
22//!
33
4- use crate :: HtsGetError ;
54use crate :: error:: Result as HtsGetResult ;
65use crate :: middleware:: error:: Error :: AuthBuilderError ;
76use crate :: middleware:: error:: Result ;
7+ use crate :: { Endpoint , HtsGetError } ;
88use cfg_if:: cfg_if;
99use headers:: authorization:: Bearer ;
1010use headers:: { Authorization , Header } ;
@@ -72,6 +72,7 @@ impl Debug for Auth {
7272}
7373
7474const FORWARD_HEADER_PREFIX : & str = "Htsget-Context-" ;
75+ const ENDPOINT_TYPE_NAME : & str = "Endpoint-Type" ;
7576
7677impl Auth {
7778 /// Get the config for this auth layer instance.
@@ -134,6 +135,7 @@ impl Auth {
134135 & self ,
135136 request_headers : & HeaderMap ,
136137 request_extensions : Option < Value > ,
138+ request_endpoint : & Endpoint ,
137139 ) -> HtsGetResult < HeaderMap > {
138140 let mut forwarded_headers = if self . config . passthrough_auth ( ) {
139141 let auth_header = request_headers
@@ -186,14 +188,20 @@ impl Auth {
186188 } ) ?;
187189
188190 let header_name =
189- HeaderName :: from_str ( & format ! ( "{}{}" , FORWARD_HEADER_PREFIX , extension. name( ) ) )
190- . map_err ( |err| HtsGetError :: InternalError ( err. to_string ( ) ) ) ?;
191- let value = HeaderValue :: from_str ( value)
192- . map_err ( |err| HtsGetError :: InternalError ( err. to_string ( ) ) ) ?;
191+ HeaderName :: from_str ( & format ! ( "{}{}" , FORWARD_HEADER_PREFIX , extension. name( ) ) ) ?;
192+ let value = HeaderValue :: from_str ( value) ?;
193193 forwarded_headers. insert ( header_name, value) ;
194194 }
195195 }
196196
197+ if self . config . forward_endpoint_type ( ) {
198+ let header_name =
199+ HeaderName :: from_str ( & format ! ( "{}{}" , FORWARD_HEADER_PREFIX , ENDPOINT_TYPE_NAME ) ) ?;
200+ let value = HeaderValue :: from_str ( & request_endpoint. to_string ( ) ) ?;
201+
202+ forwarded_headers. insert ( header_name, value) ;
203+ }
204+
197205 Ok ( forwarded_headers)
198206 }
199207
@@ -204,10 +212,12 @@ impl Auth {
204212 & mut self ,
205213 headers : & HeaderMap ,
206214 request_extensions : Option < Value > ,
215+ request_endpoint : & Endpoint ,
207216 ) -> HtsGetResult < Option < AuthorizationRestrictions > > {
208217 match self . config . authorization_url ( ) {
209218 Some ( UrlOrStatic :: Url ( uri) ) => {
210- let forwarded_headers = self . forwarded_headers ( headers, request_extensions) ?;
219+ let forwarded_headers =
220+ self . forwarded_headers ( headers, request_extensions, request_endpoint) ?;
211221
212222 self
213223 . fetch_from_url ( & uri. to_string ( ) , forwarded_headers)
@@ -408,9 +418,10 @@ impl Auth {
408418 path : & str ,
409419 queries : & mut [ Query ] ,
410420 request_extensions : Option < Value > ,
421+ endpoint : & Endpoint ,
411422 ) -> HtsGetResult < Option < AuthorizationRestrictions > > {
412423 let restrictions = self
413- . query_authorization_service ( headers, request_extensions)
424+ . query_authorization_service ( headers, request_extensions, endpoint )
414425 . await ?;
415426
416427 if let Some ( restrictions) = restrictions {
@@ -601,7 +612,9 @@ mod tests {
601612 ( "Custom1" . parse ( ) . unwrap ( ) , "Value" . parse ( ) . unwrap ( ) ) ,
602613 ( "Custom2" . parse ( ) . unwrap ( ) , "Value" . parse ( ) . unwrap ( ) ) ,
603614 ] ) ;
604- let forwarded_headers = result. forwarded_headers ( & request_headers, None ) . unwrap ( ) ;
615+ let forwarded_headers = result
616+ . forwarded_headers ( & request_headers, None , & Endpoint :: Reads )
617+ . unwrap ( ) ;
605618 assert_eq ! (
606619 forwarded_headers,
607620 HeaderMap :: from_iter( [
@@ -624,7 +637,9 @@ mod tests {
624637 . unwrap ( ) ;
625638 let result = AuthBuilder :: default ( ) . with_config ( config) . build ( ) . unwrap ( ) ;
626639
627- let forwarded_headers = result. forwarded_headers ( & request_headers, None ) . unwrap ( ) ;
640+ let forwarded_headers = result
641+ . forwarded_headers ( & request_headers, None , & Endpoint :: Reads )
642+ . unwrap ( ) ;
628643 assert_eq ! (
629644 forwarded_headers,
630645 HeaderMap :: from_iter( [
@@ -648,12 +663,13 @@ mod tests {
648663 let config = builder
649664 . clone ( )
650665 . forward_headers ( vec ! [ "Custom1" . to_string( ) ] )
651- . passthrough_auth ( false )
652666 . build ( )
653667 . unwrap ( ) ;
654668 let result = AuthBuilder :: default ( ) . with_config ( config) . build ( ) . unwrap ( ) ;
655669
656- let forwarded_headers = result. forwarded_headers ( & request_headers, None ) . unwrap ( ) ;
670+ let forwarded_headers = result
671+ . forwarded_headers ( & request_headers, None , & Endpoint :: Reads )
672+ . unwrap ( ) ;
657673 assert_eq ! (
658674 forwarded_headers,
659675 HeaderMap :: from_iter( [ (
@@ -663,11 +679,11 @@ mod tests {
663679 ) ;
664680
665681 let config = builder
682+ . clone ( )
666683 . forward_extensions ( vec ! [ ForwardExtensions :: new(
667684 "$.Key" . to_string( ) ,
668685 "Custom1" . to_string( ) ,
669686 ) ] )
670- . passthrough_auth ( false )
671687 . build ( )
672688 . unwrap ( ) ;
673689 let result = AuthBuilder :: default ( ) . with_config ( config) . build ( ) . unwrap ( ) ;
@@ -678,6 +694,7 @@ mod tests {
678694 Some ( json ! ( {
679695 "Key" : "Value"
680696 } ) ) ,
697+ & Endpoint :: Reads ,
681698 )
682699 . unwrap ( ) ;
683700 assert_eq ! (
@@ -687,6 +704,22 @@ mod tests {
687704 "Value" . parse( ) . unwrap( )
688705 ) , ] )
689706 ) ;
707+
708+ let config = builder. forward_endpoint_type ( true ) . build ( ) . unwrap ( ) ;
709+ let result = AuthBuilder :: default ( ) . with_config ( config) . build ( ) . unwrap ( ) ;
710+
711+ let forwarded_headers = result
712+ . forwarded_headers ( & request_headers, None , & Endpoint :: Variants )
713+ . unwrap ( ) ;
714+ assert_eq ! (
715+ forwarded_headers,
716+ HeaderMap :: from_iter( [ (
717+ format!( "{}{}" , FORWARD_HEADER_PREFIX , ENDPOINT_TYPE_NAME )
718+ . parse( )
719+ . unwrap( ) ,
720+ "variants" . parse( ) . unwrap( )
721+ ) , ] )
722+ ) ;
690723 }
691724
692725 #[ test]
0 commit comments