From 95f09b7926efb06253b78fdb361e94f34e088a7b Mon Sep 17 00:00:00 2001 From: Lyphion Date: Fri, 23 Aug 2024 19:31:48 +0200 Subject: [PATCH] Implemented basis Token Endpoint fetching Added retrieving of Token Endpoint if only the Issuer is provided to the builder Added some documentation to OAuth2Client and IOAuth2Client Remove synchronous `Build` method. --- .../RabbitMQ.Client.OAuth2/IOAuth2Client.cs | 12 ++ .../RabbitMQ.Client.OAuth2/OAuth2Client.cs | 185 ++++++++++++++---- .../PublicAPI.Shipped.txt | 3 +- .../PublicAPI.Unshipped.txt | 1 + projects/Test/OAuth2/TestOAuth2.cs | 31 +-- projects/Test/OAuth2/TestOAuth2Client.cs | 22 ++- 6 files changed, 193 insertions(+), 61 deletions(-) diff --git a/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs b/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs index ee1f025a9..b79d0627e 100644 --- a/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs +++ b/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs @@ -36,7 +36,19 @@ namespace RabbitMQ.Client.OAuth2 { public interface IOAuth2Client { + /// + /// Request a new AccessToken from the Token Endpoint. + /// + /// Cancellation token for this request + /// Token with Access and Refresh Token Task RequestTokenAsync(CancellationToken cancellationToken = default); + + /// + /// Request a new AccessToken using the Refresh Token from the Token Endpoint. + /// + /// Token with the Refresh Token + /// Cancellation token for this request + /// Token with Access and Refresh Token Task RefreshTokenAsync(IToken token, CancellationToken cancellationToken = default); } } diff --git a/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs b/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs index d14c0f725..79e293e2a 100644 --- a/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs +++ b/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs @@ -42,58 +42,147 @@ namespace RabbitMQ.Client.OAuth2 { public class OAuth2ClientBuilder { + /// + /// Discovery endpoint subpath for all OpenID Connect issuers. + /// + const string DISCOVERY_ENDPOINT = ".well-known/openid-configuration"; + private readonly string _clientId; private readonly string _clientSecret; - private readonly Uri _tokenEndpoint; + + // At least one of the following Uris is not null + private readonly Uri? _tokenEndpoint; + private readonly Uri? _issuer; + private string? _scope; private IDictionary? _additionalRequestParameters; private HttpClientHandler? _httpClientHandler; - public OAuth2ClientBuilder(string clientId, string clientSecret, Uri tokenEndpoint) + /// + /// Create a new builder for creating s. + /// + /// Id of the client + /// Secret of the client + /// Endpoint to receive the Access Token + /// Issuer of the Access Token. Used to automaticly receive the Token Endpoint while building + /// + /// Either or must be provided. + /// + public OAuth2ClientBuilder(string clientId, string clientSecret, Uri? tokenEndpoint = null, Uri? issuer = null) { _clientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); _clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret)); - _tokenEndpoint = tokenEndpoint ?? throw new ArgumentNullException(nameof(tokenEndpoint)); + + if (tokenEndpoint is null && issuer is null) + { + throw new ArgumentException("Either tokenEndpoint or issuer is required"); + } + + _tokenEndpoint = tokenEndpoint; + _issuer = issuer; } + /// + /// Set the requested scopes for the client. + /// + /// OAuth scopes to request from the Issuer public OAuth2ClientBuilder SetScope(string scope) { _scope = scope ?? throw new ArgumentNullException(nameof(scope)); return this; } + /// + /// Set custom HTTP Client handler for requests of the OAuth2 client. + /// + /// Custom handler for HTTP requests public OAuth2ClientBuilder SetHttpClientHandler(HttpClientHandler handler) { _httpClientHandler = handler ?? throw new ArgumentNullException(nameof(handler)); return this; } + /// + /// Add a additional request parameter to each HTTP request. + /// + /// Name of the parameter + /// Value of the parameter public OAuth2ClientBuilder AddRequestParameter(string param, string paramValue) { - if (param == null) + if (param is null) { - throw new ArgumentNullException("param is null"); + throw new ArgumentNullException(nameof(param)); } - if (paramValue == null) + if (paramValue is null) { - throw new ArgumentNullException("paramValue is null"); + throw new ArgumentNullException(nameof(paramValue)); } - if (_additionalRequestParameters == null) - { - _additionalRequestParameters = new Dictionary(); - } + _additionalRequestParameters ??= new Dictionary(); _additionalRequestParameters[param] = paramValue; return this; } - public IOAuth2Client Build() + /// + /// Build the with the provided properties of the builder. + /// + /// Cancellation token for this method + /// Configured OAuth2Client + public async ValueTask BuildAsync(CancellationToken cancellationToken = default) { + // Check if Token Endpoint is missing -> Use Issuer to receive Token Endpoint + if (_tokenEndpoint is null) + { + Uri tokenEndpoint = await GetTokenEndpointFromIssuerAsync(cancellationToken).ConfigureAwait(false); + return new OAuth2Client(_clientId, _clientSecret, tokenEndpoint, + _scope, _additionalRequestParameters, _httpClientHandler); + } + return new OAuth2Client(_clientId, _clientSecret, _tokenEndpoint, _scope, _additionalRequestParameters, _httpClientHandler); } + + /// + /// Receive Token Endpoint from discovery page of the Issuer. + /// + /// Cancellation token for this request + /// Uri of the Token Endpoint + private async Task GetTokenEndpointFromIssuerAsync(CancellationToken cancellationToken = default) + { + if (_issuer is null) + { + throw new InvalidOperationException("The issuer is required"); + } + + using HttpClient httpClient = _httpClientHandler is null + ? new HttpClient() + : new HttpClient(_httpClientHandler, false); + + httpClient.DefaultRequestHeaders.Accept.Clear(); + httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + + // Build endpoint from Issuer and dicovery endpoint, we can't use the Uri overload because the Issuer Uri may not have a trailing '/' + string tempIssuer = _issuer.AbsoluteUri.EndsWith("/") ? _issuer.AbsoluteUri : _issuer.AbsoluteUri + "/"; + Uri discoveryEndpoint = new Uri(tempIssuer + DISCOVERY_ENDPOINT); + + using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Get, discoveryEndpoint); + using HttpResponseMessage response = await httpClient.SendAsync(req, cancellationToken) + .ConfigureAwait(false); + + response.EnsureSuccessStatusCode(); + + OpenIDConnectDiscovery? discovery = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken) + .ConfigureAwait(false); + + if (discovery is null || string.IsNullOrEmpty(discovery.TokenEndpoint)) + { + throw new InvalidOperationException("No token endpoint was found"); + } + + return new Uri(discovery.TokenEndpoint); + } } /** @@ -119,7 +208,7 @@ internal class OAuth2Client : IOAuth2Client, IDisposable public static readonly IDictionary EMPTY = new Dictionary(); - private HttpClient _httpClient; + private readonly HttpClient _httpClient; public OAuth2Client(string clientId, string clientSecret, Uri tokenEndpoint, string? scope, @@ -132,30 +221,26 @@ public OAuth2Client(string clientId, string clientSecret, Uri tokenEndpoint, _additionalRequestParameters = additionalRequestParameters ?? EMPTY; _tokenEndpoint = tokenEndpoint; - if (httpClientHandler is null) - { - _httpClient = new HttpClient(); - } - else - { - _httpClient = new HttpClient(httpClientHandler, false); - } + _httpClient = httpClientHandler is null + ? new HttpClient() + : new HttpClient(httpClientHandler, false); _httpClient.DefaultRequestHeaders.Accept.Clear(); _httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); } + /// public async Task RequestTokenAsync(CancellationToken cancellationToken = default) { using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint); req.Content = new FormUrlEncodedContent(BuildRequestParameters()); - using HttpResponseMessage response = await _httpClient.SendAsync(req) + using HttpResponseMessage response = await _httpClient.SendAsync(req, cancellationToken) .ConfigureAwait(false); response.EnsureSuccessStatusCode(); - JsonToken? token = await response.Content.ReadFromJsonAsync() + JsonToken? token = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken) .ConfigureAwait(false); if (token is null) @@ -163,31 +248,28 @@ public async Task RequestTokenAsync(CancellationToken cancellationToken // TODO specific exception? throw new InvalidOperationException("token is null"); } - else - { - return new Token(token); - } + + return new Token(token); } + /// public async Task RefreshTokenAsync(IToken token, CancellationToken cancellationToken = default) { - if (token.RefreshToken == null) + if (token.RefreshToken is null) { throw new InvalidOperationException("Token has no Refresh Token"); } - using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint) - { - Content = new FormUrlEncodedContent(BuildRefreshParameters(token)) - }; + using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint); + req.Content = new FormUrlEncodedContent(BuildRefreshParameters(token)); - using HttpResponseMessage response = await _httpClient.SendAsync(req) + using HttpResponseMessage response = await _httpClient.SendAsync(req, cancellationToken) .ConfigureAwait(false); response.EnsureSuccessStatusCode(); - JsonToken? refreshedToken = await response.Content.ReadFromJsonAsync() + JsonToken? refreshedToken = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken) .ConfigureAwait(false); if (refreshedToken is null) @@ -195,10 +277,8 @@ public async Task RefreshTokenAsync(IToken token, // TODO specific exception? throw new InvalidOperationException("refreshed token is null"); } - else - { - return new Token(refreshedToken); - } + + return new Token(refreshedToken); } public void Dispose() @@ -214,9 +294,9 @@ private Dictionary BuildRequestParameters() { CLIENT_SECRET, _clientSecret } }; - if (_scope != null && _scope.Length > 0) + if (!string.IsNullOrEmpty(_scope)) { - dict.Add(SCOPE, _scope); + dict.Add(SCOPE, _scope!); } dict.Add(GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS); @@ -227,8 +307,7 @@ private Dictionary BuildRequestParameters() private Dictionary BuildRefreshParameters(IToken token) { Dictionary dict = BuildRequestParameters(); - dict.Remove(GRANT_TYPE); - dict.Add(GRANT_TYPE, REFRESH_TOKEN); + dict[GRANT_TYPE] = REFRESH_TOKEN; if (_scope != null) { @@ -284,4 +363,26 @@ public long ExpiresIn get; set; } } + + /// + /// Minimal version of the properties of the discovery endpoint. + /// + internal class OpenIDConnectDiscovery + { + public OpenIDConnectDiscovery() + { + TokenEndpoint = string.Empty; + } + + public OpenIDConnectDiscovery(string tokenEndpoint) + { + TokenEndpoint = tokenEndpoint; + } + + [JsonPropertyName("token_endpoint")] + public string TokenEndpoint + { + get; set; + } + } } diff --git a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt index ddd490940..27f02e91c 100644 --- a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt +++ b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt @@ -7,8 +7,7 @@ RabbitMQ.Client.OAuth2.IToken.HasExpired.get -> bool RabbitMQ.Client.OAuth2.IToken.RefreshToken.get -> string RabbitMQ.Client.OAuth2.OAuth2ClientBuilder RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.AddRequestParameter(string param, string paramValue) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder -RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.Build() -> RabbitMQ.Client.OAuth2.IOAuth2Client -RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.OAuth2ClientBuilder(string clientId, string clientSecret, System.Uri tokenEndpoint) -> void +RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.OAuth2ClientBuilder(string! clientId, string! clientSecret, System.Uri? tokenEndpoint = null, System.Uri? issuer = null) -> void RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.SetScope(string scope) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.Name.get -> string diff --git a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt index f93749f14..9728f46aa 100644 --- a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt +++ b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt @@ -10,6 +10,7 @@ RabbitMQ.Client.OAuth2.CredentialsRefresherEventSource.Stopped(string! name) -> RabbitMQ.Client.OAuth2.IOAuth2Client.RefreshTokenAsync(RabbitMQ.Client.OAuth2.IToken! token, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! RabbitMQ.Client.OAuth2.IOAuth2Client.RequestTokenAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! RabbitMQ.Client.OAuth2.NotifyCredentialsRefreshedAsync +RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.BuildAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.SetHttpClientHandler(System.Net.Http.HttpClientHandler! handler) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder! RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.Dispose() -> void RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.GetCredentialsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! diff --git a/projects/Test/OAuth2/TestOAuth2.cs b/projects/Test/OAuth2/TestOAuth2.cs index 9f87f90b1..176e4da6b 100644 --- a/projects/Test/OAuth2/TestOAuth2.cs +++ b/projects/Test/OAuth2/TestOAuth2.cs @@ -46,29 +46,33 @@ public class TestOAuth2 : IAsyncLifetime { private const string Exchange = "test_direct"; + private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); private readonly SemaphoreSlim _doneEvent = new SemaphoreSlim(0, 1); private readonly ITestOutputHelper _testOutputHelper; - private readonly IConnectionFactory _connectionFactory; private readonly int _tokenExpiresInSeconds; - private readonly OAuth2ClientCredentialsProvider _producerCredentialsProvider; - private readonly OAuth2ClientCredentialsProvider _httpApiCredentialsProvider; - private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource(); + private OAuth2ClientCredentialsProvider? _producerCredentialsProvider; + private OAuth2ClientCredentialsProvider? _httpApiCredentialsProvider; + private IConnectionFactory? _connectionFactory; private IConnection? _connection; private CredentialsRefresher? _credentialsRefresher; public TestOAuth2(ITestOutputHelper testOutputHelper) { _testOutputHelper = testOutputHelper; + _tokenExpiresInSeconds = OAuth2OptionsBase.TokenExpiresInSeconds; + } + public async Task InitializeAsync() + { string modeStr = Environment.GetEnvironmentVariable("OAUTH2_MODE") ?? "uaa"; Mode mode = (Mode)Enum.Parse(typeof(Mode), modeStr.ToLowerInvariant()); var producerOptions = new OAuth2ProducerOptions(mode); - _producerCredentialsProvider = GetCredentialsProvider(producerOptions); + _producerCredentialsProvider = await GetCredentialsProviderAsync(producerOptions); var httpApiOptions = new OAuth2HttpApiOptions(mode); - _httpApiCredentialsProvider = GetCredentialsProvider(httpApiOptions); + _httpApiCredentialsProvider = await GetCredentialsProviderAsync(httpApiOptions); _connectionFactory = new ConnectionFactory { @@ -77,11 +81,6 @@ public TestOAuth2(ITestOutputHelper testOutputHelper) ClientProvidedName = nameof(TestOAuth2) }; - _tokenExpiresInSeconds = OAuth2OptionsBase.TokenExpiresInSeconds; - } - - public async Task InitializeAsync() - { _connection = await _connectionFactory.CreateConnectionAsync(_cancellationTokenSource.Token); _connection.ConnectionShutdown += (sender, ea) => @@ -119,8 +118,9 @@ public async Task DisposeAsync() finally { _doneEvent.Dispose(); - _producerCredentialsProvider.Dispose(); + _producerCredentialsProvider?.Dispose(); _connection?.Dispose(); + _cancellationTokenSource.Dispose(); } } @@ -174,6 +174,7 @@ public async Task IntegrationTest() async Task CloseConnection() { Assert.NotNull(_connection); + Assert.NotNull(_httpApiCredentialsProvider); Credentials httpApiCredentials = await _httpApiCredentialsProvider.GetCredentialsAsync(); closeConnectionUtil = new Util(_testOutputHelper, "mgt_api_client", httpApiCredentials.Password); await closeConnectionUtil.CloseConnectionAsync(_connection.ClientProvidedName); @@ -217,6 +218,7 @@ async Task CloseConnection() [Fact] public async Task SecondConnectionCrashes_GH1429() { + Assert.NotNull(_connectionFactory); // https://github.com/rabbitmq/rabbitmq-dotnet-client/issues/1429 IConnection secondConnection = await _connectionFactory.CreateConnectionAsync(CancellationToken.None); secondConnection.Dispose(); @@ -267,7 +269,7 @@ private async Task ConsumeAsync(IChannel consumeChannel) await consumeChannel.BasicCancelAsync(consumerTag); } - private OAuth2ClientCredentialsProvider GetCredentialsProvider(OAuth2OptionsBase opts) + private async Task GetCredentialsProviderAsync(OAuth2OptionsBase opts) { _testOutputHelper.WriteLine("OAuth2Client "); _testOutputHelper.WriteLine($"- ClientId: {opts.ClientId}"); @@ -276,7 +278,8 @@ private OAuth2ClientCredentialsProvider GetCredentialsProvider(OAuth2OptionsBase _testOutputHelper.WriteLine($"- Scope: {opts.Scope}"); var tokenEndpointUri = new Uri(opts.TokenEndpoint); - IOAuth2Client oAuth2Client = new OAuth2ClientBuilder(opts.ClientId, opts.ClientSecret, tokenEndpointUri).Build(); + var builder = new OAuth2ClientBuilder(opts.ClientId, opts.ClientSecret, tokenEndpointUri); + IOAuth2Client oAuth2Client = await builder.BuildAsync(); return new OAuth2ClientCredentialsProvider(opts.Name, oAuth2Client); } diff --git a/projects/Test/OAuth2/TestOAuth2Client.cs b/projects/Test/OAuth2/TestOAuth2Client.cs index 6b6e50ced..6f02b6136 100644 --- a/projects/Test/OAuth2/TestOAuth2Client.cs +++ b/projects/Test/OAuth2/TestOAuth2Client.cs @@ -41,24 +41,36 @@ namespace OAuth2Test { - public class TestOAuth2Client + public class TestOAuth2Client : IAsyncLifetime { protected string _client_id = "producer"; protected string _client_secret = "kbOFBXI9tANgKUq8vXHLhT6YhbivgXxn"; protected WireMockServer _oauthServer; - protected IOAuth2Client _client; + protected IOAuth2Client? _client; public TestOAuth2Client() { _oauthServer = WireMockServer.Start(); + } + + public async Task InitializeAsync() + { var uri = new Uri(_oauthServer.Url + "/token"); - _client = new OAuth2ClientBuilder(_client_id, _client_secret, uri).Build(); + var builder = new OAuth2ClientBuilder(_client_id, _client_secret, uri); + _client = await builder.BuildAsync(); + } + + public Task DisposeAsync() + { + return Task.CompletedTask; } [Fact] public async Task TestRequestToken() { + Assert.NotNull(_client); + JsonToken expectedJsonToken = new JsonToken("the_access_token", "the_refresh_token", TimeSpan.FromSeconds(10)); ExpectTokenRequest(new RequestFormMatcher() .WithParam("client_id", _client_id) @@ -76,6 +88,8 @@ public async Task TestRequestToken() [Fact] public async Task TestRefreshToken() { + Assert.NotNull(_client); + const string accessToken0 = "the_access_token"; const string accessToken1 = "the_access_token_2"; const string refreshToken = "the_refresh_token"; @@ -110,6 +124,8 @@ public async Task TestRefreshToken() [Fact] public async Task TestInvalidCredentials() { + Assert.NotNull(_client); + _oauthServer .Given( Request.Create()