diff --git a/deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl b/deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl index f5d0e6559bd5..4652c16ddcd1 100644 --- a/deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl +++ b/deps/rabbitmq_auth_backend_oauth2/include/oauth2.hrl @@ -22,6 +22,8 @@ %% End of Key JWT fields +-type raw_jwt_token() :: binary() | #{binary() => any()}. +-type decoded_jwt_token() :: #{binary() => any()}. -record(internal_oauth_provider, { id :: oauth_provider_id(), diff --git a/deps/rabbitmq_auth_backend_oauth2/src/rabbit_auth_backend_oauth2.erl b/deps/rabbitmq_auth_backend_oauth2/src/rabbit_auth_backend_oauth2.erl index 849349f04780..6a45d3658e89 100644 --- a/deps/rabbitmq_auth_backend_oauth2/src/rabbit_auth_backend_oauth2.erl +++ b/deps/rabbitmq_auth_backend_oauth2/src/rabbit_auth_backend_oauth2.erl @@ -38,15 +38,11 @@ -endif. %% -%% App environment +%% Types %% - -%% a term defined for Rich Authorization Request tokens to identify a RabbitMQ permission -%% verify server_server_id aud field is on the aud field -%% a term used by the IdentityServer community -%% scope aliases map "role names" to a set of scopes - +-type ok_extracted_auth_user() :: {ok, rabbit_types:auth_user()}. +-type auth_user_extraction_fun() :: fun((decoded_jwt_token()) -> any()). %% %% API @@ -58,6 +54,11 @@ description() -> %%-------------------------------------------------------------------- +-spec user_login_authentication(rabbit_types:username(), [term()] | map()) -> + {'ok', rabbit_types:auth_user()} | + {'refused', string(), [any()]} | + {'error', any()}. + user_login_authentication(Username, AuthProps) -> case authenticate(Username, AuthProps) of {refused, Msg, Args} = AuthResult -> @@ -67,12 +68,21 @@ user_login_authentication(Username, AuthProps) -> AuthResult end. +-spec user_login_authorization(rabbit_types:username(), [term()] | map()) -> + {'ok', any()} | + {'ok', any(), any()} | + {'refused', string(), [any()]} | + {'error', any()}. + user_login_authorization(Username, AuthProps) -> case authenticate(Username, AuthProps) of {ok, #auth_user{impl = Impl}} -> {ok, Impl}; Else -> Else end. +-spec check_vhost_access(AuthUser :: rabbit_types:auth_user(), + VHost :: rabbit_types:vhost(), + AuthzData :: rabbit_types:authz_data()) -> boolean() | {'error', any()}. check_vhost_access(#auth_user{impl = DecodedTokenFun}, VHost, _AuthzData) -> with_decoded_token(DecodedTokenFun(), @@ -136,6 +146,11 @@ expiry_timestamp(#auth_user{impl = DecodedTokenFun}) -> %%-------------------------------------------------------------------- +-spec authenticate(Username, Props) -> Result + when Username :: rabbit_types:username(), + Props :: list() | map(), + Result :: {ok, any()} | {refused, list(), list()} | {refused, {error, any()}}. + authenticate(_, AuthProps0) -> AuthProps = to_map(AuthProps0), Token = token_from_context(AuthProps), @@ -148,17 +163,8 @@ authenticate(_, AuthProps0) -> {refused, "Authentication using an OAuth 2/JWT token failed: provided token is invalid", []}; {refused, Err} -> {refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]}; - {ok, DecodedToken} -> - Func = fun(Token0) -> - Username = username_from( - ResourceServer#resource_server.preferred_username_claims, - Token0), - Tags = tags_from(Token0), - {ok, #auth_user{username = Username, - tags = Tags, - impl = fun() -> Token0 end}} - end, - case with_decoded_token(DecodedToken, Func) of + {ok, DecodedToken} -> + case with_decoded_token(DecodedToken, fun(In) -> auth_user_from_token(In, ResourceServer) end) of {error, Err} -> {refused, "Authentication using an OAuth 2/JWT token failed: ~tp", [Err]}; Else -> @@ -166,6 +172,11 @@ authenticate(_, AuthProps0) -> end end end. + +-spec with_decoded_token(Token, Fun) -> Result + when Token :: decoded_jwt_token(), + Fun :: auth_user_extraction_fun(), + Result :: {ok, any()} | {'error', any()}. with_decoded_token(DecodedToken, Fun) -> case validate_token_expiry(DecodedToken) of ok -> Fun(DecodedToken); @@ -173,6 +184,21 @@ with_decoded_token(DecodedToken, Fun) -> rabbit_log:error(Msg), Err end. + +%% This is a helper function used with HOFs that may return errors. +-spec auth_user_from_token(Token, ResourceServer) -> Result + when Token :: decoded_jwt_token(), + ResourceServer :: resource_server(), + Result :: ok_extracted_auth_user(). +auth_user_from_token(Token0, ResourceServer) -> + Username = username_from( + ResourceServer#resource_server.preferred_username_claims, + Token0), + Tags = tags_from(Token0), + {ok, #auth_user{username = Username, + tags = Tags, + impl = fun() -> Token0 end}}. + ensure_same_username(PreferredUsernameClaims, CurrentDecodedToken, NewDecodedToken) -> CurUsername = username_from(PreferredUsernameClaims, CurrentDecodedToken), case {CurUsername, username_from(PreferredUsernameClaims, NewDecodedToken)} of @@ -188,12 +214,10 @@ validate_token_expiry(#{<<"exp">> := Exp}) when is_integer(Exp) -> end; validate_token_expiry(#{}) -> ok. --spec check_token(binary() | map(), {resource_server(), internal_oauth_provider()}) -> - {'ok', map()} | - {'error', term() }| - {'refused', 'signature_invalid' | - {'error', term()} | - {'invalid_aud', term()}}. +-spec check_token(raw_jwt_token(), {resource_server(), internal_oauth_provider()}) -> + {'ok', decoded_jwt_token()} | + {'error', term() } | + {'refused', 'signature_invalid' | {'error', term()} | {'invalid_aud', term()}}. check_token(DecodedToken, _) when is_map(DecodedToken) -> {ok, DecodedToken}; @@ -206,7 +230,7 @@ check_token(Token, {ResourceServer, InternalOAuthProvider}) -> end. -spec normalize_token_scope( - ResourceServer :: resource_server(), DecodedToken :: map()) -> map(). + ResourceServer :: resource_server(), DecodedToken :: decoded_jwt_token()) -> map(). normalize_token_scope(ResourceServer, Payload) -> Payload0 = maps:map(fun(K, V) -> case K of @@ -395,7 +419,7 @@ resolve_scope_var(Elem, Token, Vhost) -> end) end. --spec tags_from(map()) -> list(atom()). +-spec tags_from(decoded_jwt_token()) -> list(atom()). tags_from(DecodedToken) -> Scopes = maps:get(?SCOPE_JWT_FIELD, DecodedToken, []), TagScopes = filter_matching_scope_prefix_and_drop_it(Scopes, ?TAG_SCOPE_PREFIX),