diff --git a/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs b/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs
index ee1f025a9..b79d0627e 100644
--- a/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs
+++ b/projects/RabbitMQ.Client.OAuth2/IOAuth2Client.cs
@@ -36,7 +36,19 @@ namespace RabbitMQ.Client.OAuth2
{
public interface IOAuth2Client
{
+ ///
+ /// Request a new AccessToken from the Token Endpoint.
+ ///
+ /// Cancellation token for this request
+ /// Token with Access and Refresh Token
Task RequestTokenAsync(CancellationToken cancellationToken = default);
+
+ ///
+ /// Request a new AccessToken using the Refresh Token from the Token Endpoint.
+ ///
+ /// Token with the Refresh Token
+ /// Cancellation token for this request
+ /// Token with Access and Refresh Token
Task RefreshTokenAsync(IToken token, CancellationToken cancellationToken = default);
}
}
diff --git a/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs b/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs
index d14c0f725..79e293e2a 100644
--- a/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs
+++ b/projects/RabbitMQ.Client.OAuth2/OAuth2Client.cs
@@ -42,58 +42,147 @@ namespace RabbitMQ.Client.OAuth2
{
public class OAuth2ClientBuilder
{
+ ///
+ /// Discovery endpoint subpath for all OpenID Connect issuers.
+ ///
+ const string DISCOVERY_ENDPOINT = ".well-known/openid-configuration";
+
private readonly string _clientId;
private readonly string _clientSecret;
- private readonly Uri _tokenEndpoint;
+
+ // At least one of the following Uris is not null
+ private readonly Uri? _tokenEndpoint;
+ private readonly Uri? _issuer;
+
private string? _scope;
private IDictionary? _additionalRequestParameters;
private HttpClientHandler? _httpClientHandler;
- public OAuth2ClientBuilder(string clientId, string clientSecret, Uri tokenEndpoint)
+ ///
+ /// Create a new builder for creating s.
+ ///
+ /// Id of the client
+ /// Secret of the client
+ /// Endpoint to receive the Access Token
+ /// Issuer of the Access Token. Used to automaticly receive the Token Endpoint while building
+ ///
+ /// Either or must be provided.
+ ///
+ public OAuth2ClientBuilder(string clientId, string clientSecret, Uri? tokenEndpoint = null, Uri? issuer = null)
{
_clientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
_clientSecret = clientSecret ?? throw new ArgumentNullException(nameof(clientSecret));
- _tokenEndpoint = tokenEndpoint ?? throw new ArgumentNullException(nameof(tokenEndpoint));
+
+ if (tokenEndpoint is null && issuer is null)
+ {
+ throw new ArgumentException("Either tokenEndpoint or issuer is required");
+ }
+
+ _tokenEndpoint = tokenEndpoint;
+ _issuer = issuer;
}
+ ///
+ /// Set the requested scopes for the client.
+ ///
+ /// OAuth scopes to request from the Issuer
public OAuth2ClientBuilder SetScope(string scope)
{
_scope = scope ?? throw new ArgumentNullException(nameof(scope));
return this;
}
+ ///
+ /// Set custom HTTP Client handler for requests of the OAuth2 client.
+ ///
+ /// Custom handler for HTTP requests
public OAuth2ClientBuilder SetHttpClientHandler(HttpClientHandler handler)
{
_httpClientHandler = handler ?? throw new ArgumentNullException(nameof(handler));
return this;
}
+ ///
+ /// Add a additional request parameter to each HTTP request.
+ ///
+ /// Name of the parameter
+ /// Value of the parameter
public OAuth2ClientBuilder AddRequestParameter(string param, string paramValue)
{
- if (param == null)
+ if (param is null)
{
- throw new ArgumentNullException("param is null");
+ throw new ArgumentNullException(nameof(param));
}
- if (paramValue == null)
+ if (paramValue is null)
{
- throw new ArgumentNullException("paramValue is null");
+ throw new ArgumentNullException(nameof(paramValue));
}
- if (_additionalRequestParameters == null)
- {
- _additionalRequestParameters = new Dictionary();
- }
+ _additionalRequestParameters ??= new Dictionary();
_additionalRequestParameters[param] = paramValue;
return this;
}
- public IOAuth2Client Build()
+ ///
+ /// Build the with the provided properties of the builder.
+ ///
+ /// Cancellation token for this method
+ /// Configured OAuth2Client
+ public async ValueTask BuildAsync(CancellationToken cancellationToken = default)
{
+ // Check if Token Endpoint is missing -> Use Issuer to receive Token Endpoint
+ if (_tokenEndpoint is null)
+ {
+ Uri tokenEndpoint = await GetTokenEndpointFromIssuerAsync(cancellationToken).ConfigureAwait(false);
+ return new OAuth2Client(_clientId, _clientSecret, tokenEndpoint,
+ _scope, _additionalRequestParameters, _httpClientHandler);
+ }
+
return new OAuth2Client(_clientId, _clientSecret, _tokenEndpoint,
_scope, _additionalRequestParameters, _httpClientHandler);
}
+
+ ///
+ /// Receive Token Endpoint from discovery page of the Issuer.
+ ///
+ /// Cancellation token for this request
+ /// Uri of the Token Endpoint
+ private async Task GetTokenEndpointFromIssuerAsync(CancellationToken cancellationToken = default)
+ {
+ if (_issuer is null)
+ {
+ throw new InvalidOperationException("The issuer is required");
+ }
+
+ using HttpClient httpClient = _httpClientHandler is null
+ ? new HttpClient()
+ : new HttpClient(_httpClientHandler, false);
+
+ httpClient.DefaultRequestHeaders.Accept.Clear();
+ httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
+
+ // Build endpoint from Issuer and dicovery endpoint, we can't use the Uri overload because the Issuer Uri may not have a trailing '/'
+ string tempIssuer = _issuer.AbsoluteUri.EndsWith("/") ? _issuer.AbsoluteUri : _issuer.AbsoluteUri + "/";
+ Uri discoveryEndpoint = new Uri(tempIssuer + DISCOVERY_ENDPOINT);
+
+ using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Get, discoveryEndpoint);
+ using HttpResponseMessage response = await httpClient.SendAsync(req, cancellationToken)
+ .ConfigureAwait(false);
+
+ response.EnsureSuccessStatusCode();
+
+ OpenIDConnectDiscovery? discovery = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken)
+ .ConfigureAwait(false);
+
+ if (discovery is null || string.IsNullOrEmpty(discovery.TokenEndpoint))
+ {
+ throw new InvalidOperationException("No token endpoint was found");
+ }
+
+ return new Uri(discovery.TokenEndpoint);
+ }
}
/**
@@ -119,7 +208,7 @@ internal class OAuth2Client : IOAuth2Client, IDisposable
public static readonly IDictionary EMPTY = new Dictionary();
- private HttpClient _httpClient;
+ private readonly HttpClient _httpClient;
public OAuth2Client(string clientId, string clientSecret, Uri tokenEndpoint,
string? scope,
@@ -132,30 +221,26 @@ public OAuth2Client(string clientId, string clientSecret, Uri tokenEndpoint,
_additionalRequestParameters = additionalRequestParameters ?? EMPTY;
_tokenEndpoint = tokenEndpoint;
- if (httpClientHandler is null)
- {
- _httpClient = new HttpClient();
- }
- else
- {
- _httpClient = new HttpClient(httpClientHandler, false);
- }
+ _httpClient = httpClientHandler is null
+ ? new HttpClient()
+ : new HttpClient(httpClientHandler, false);
_httpClient.DefaultRequestHeaders.Accept.Clear();
_httpClient.DefaultRequestHeaders.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
}
+ ///
public async Task RequestTokenAsync(CancellationToken cancellationToken = default)
{
using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint);
req.Content = new FormUrlEncodedContent(BuildRequestParameters());
- using HttpResponseMessage response = await _httpClient.SendAsync(req)
+ using HttpResponseMessage response = await _httpClient.SendAsync(req, cancellationToken)
.ConfigureAwait(false);
response.EnsureSuccessStatusCode();
- JsonToken? token = await response.Content.ReadFromJsonAsync()
+ JsonToken? token = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken)
.ConfigureAwait(false);
if (token is null)
@@ -163,31 +248,28 @@ public async Task RequestTokenAsync(CancellationToken cancellationToken
// TODO specific exception?
throw new InvalidOperationException("token is null");
}
- else
- {
- return new Token(token);
- }
+
+ return new Token(token);
}
+ ///
public async Task RefreshTokenAsync(IToken token,
CancellationToken cancellationToken = default)
{
- if (token.RefreshToken == null)
+ if (token.RefreshToken is null)
{
throw new InvalidOperationException("Token has no Refresh Token");
}
- using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint)
- {
- Content = new FormUrlEncodedContent(BuildRefreshParameters(token))
- };
+ using HttpRequestMessage req = new HttpRequestMessage(HttpMethod.Post, _tokenEndpoint);
+ req.Content = new FormUrlEncodedContent(BuildRefreshParameters(token));
- using HttpResponseMessage response = await _httpClient.SendAsync(req)
+ using HttpResponseMessage response = await _httpClient.SendAsync(req, cancellationToken)
.ConfigureAwait(false);
response.EnsureSuccessStatusCode();
- JsonToken? refreshedToken = await response.Content.ReadFromJsonAsync()
+ JsonToken? refreshedToken = await response.Content.ReadFromJsonAsync(cancellationToken: cancellationToken)
.ConfigureAwait(false);
if (refreshedToken is null)
@@ -195,10 +277,8 @@ public async Task RefreshTokenAsync(IToken token,
// TODO specific exception?
throw new InvalidOperationException("refreshed token is null");
}
- else
- {
- return new Token(refreshedToken);
- }
+
+ return new Token(refreshedToken);
}
public void Dispose()
@@ -214,9 +294,9 @@ private Dictionary BuildRequestParameters()
{ CLIENT_SECRET, _clientSecret }
};
- if (_scope != null && _scope.Length > 0)
+ if (!string.IsNullOrEmpty(_scope))
{
- dict.Add(SCOPE, _scope);
+ dict.Add(SCOPE, _scope!);
}
dict.Add(GRANT_TYPE, GRANT_TYPE_CLIENT_CREDENTIALS);
@@ -227,8 +307,7 @@ private Dictionary BuildRequestParameters()
private Dictionary BuildRefreshParameters(IToken token)
{
Dictionary dict = BuildRequestParameters();
- dict.Remove(GRANT_TYPE);
- dict.Add(GRANT_TYPE, REFRESH_TOKEN);
+ dict[GRANT_TYPE] = REFRESH_TOKEN;
if (_scope != null)
{
@@ -284,4 +363,26 @@ public long ExpiresIn
get; set;
}
}
+
+ ///
+ /// Minimal version of the properties of the discovery endpoint.
+ ///
+ internal class OpenIDConnectDiscovery
+ {
+ public OpenIDConnectDiscovery()
+ {
+ TokenEndpoint = string.Empty;
+ }
+
+ public OpenIDConnectDiscovery(string tokenEndpoint)
+ {
+ TokenEndpoint = tokenEndpoint;
+ }
+
+ [JsonPropertyName("token_endpoint")]
+ public string TokenEndpoint
+ {
+ get; set;
+ }
+ }
}
diff --git a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt
index ddd490940..27f02e91c 100644
--- a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt
+++ b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Shipped.txt
@@ -7,8 +7,7 @@ RabbitMQ.Client.OAuth2.IToken.HasExpired.get -> bool
RabbitMQ.Client.OAuth2.IToken.RefreshToken.get -> string
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.AddRequestParameter(string param, string paramValue) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder
-RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.Build() -> RabbitMQ.Client.OAuth2.IOAuth2Client
-RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.OAuth2ClientBuilder(string clientId, string clientSecret, System.Uri tokenEndpoint) -> void
+RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.OAuth2ClientBuilder(string! clientId, string! clientSecret, System.Uri? tokenEndpoint = null, System.Uri? issuer = null) -> void
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.SetScope(string scope) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.Name.get -> string
diff --git a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt
index f93749f14..9728f46aa 100644
--- a/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt
+++ b/projects/RabbitMQ.Client.OAuth2/PublicAPI.Unshipped.txt
@@ -10,6 +10,7 @@ RabbitMQ.Client.OAuth2.CredentialsRefresherEventSource.Stopped(string! name) ->
RabbitMQ.Client.OAuth2.IOAuth2Client.RefreshTokenAsync(RabbitMQ.Client.OAuth2.IToken! token, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
RabbitMQ.Client.OAuth2.IOAuth2Client.RequestTokenAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
RabbitMQ.Client.OAuth2.NotifyCredentialsRefreshedAsync
+RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.BuildAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.ValueTask
RabbitMQ.Client.OAuth2.OAuth2ClientBuilder.SetHttpClientHandler(System.Net.Http.HttpClientHandler! handler) -> RabbitMQ.Client.OAuth2.OAuth2ClientBuilder!
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.Dispose() -> void
RabbitMQ.Client.OAuth2.OAuth2ClientCredentialsProvider.GetCredentialsAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task!
diff --git a/projects/Test/OAuth2/TestOAuth2.cs b/projects/Test/OAuth2/TestOAuth2.cs
index 9f87f90b1..176e4da6b 100644
--- a/projects/Test/OAuth2/TestOAuth2.cs
+++ b/projects/Test/OAuth2/TestOAuth2.cs
@@ -46,29 +46,33 @@ public class TestOAuth2 : IAsyncLifetime
{
private const string Exchange = "test_direct";
+ private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
private readonly SemaphoreSlim _doneEvent = new SemaphoreSlim(0, 1);
private readonly ITestOutputHelper _testOutputHelper;
- private readonly IConnectionFactory _connectionFactory;
private readonly int _tokenExpiresInSeconds;
- private readonly OAuth2ClientCredentialsProvider _producerCredentialsProvider;
- private readonly OAuth2ClientCredentialsProvider _httpApiCredentialsProvider;
- private readonly CancellationTokenSource _cancellationTokenSource = new CancellationTokenSource();
+ private OAuth2ClientCredentialsProvider? _producerCredentialsProvider;
+ private OAuth2ClientCredentialsProvider? _httpApiCredentialsProvider;
+ private IConnectionFactory? _connectionFactory;
private IConnection? _connection;
private CredentialsRefresher? _credentialsRefresher;
public TestOAuth2(ITestOutputHelper testOutputHelper)
{
_testOutputHelper = testOutputHelper;
+ _tokenExpiresInSeconds = OAuth2OptionsBase.TokenExpiresInSeconds;
+ }
+ public async Task InitializeAsync()
+ {
string modeStr = Environment.GetEnvironmentVariable("OAUTH2_MODE") ?? "uaa";
Mode mode = (Mode)Enum.Parse(typeof(Mode), modeStr.ToLowerInvariant());
var producerOptions = new OAuth2ProducerOptions(mode);
- _producerCredentialsProvider = GetCredentialsProvider(producerOptions);
+ _producerCredentialsProvider = await GetCredentialsProviderAsync(producerOptions);
var httpApiOptions = new OAuth2HttpApiOptions(mode);
- _httpApiCredentialsProvider = GetCredentialsProvider(httpApiOptions);
+ _httpApiCredentialsProvider = await GetCredentialsProviderAsync(httpApiOptions);
_connectionFactory = new ConnectionFactory
{
@@ -77,11 +81,6 @@ public TestOAuth2(ITestOutputHelper testOutputHelper)
ClientProvidedName = nameof(TestOAuth2)
};
- _tokenExpiresInSeconds = OAuth2OptionsBase.TokenExpiresInSeconds;
- }
-
- public async Task InitializeAsync()
- {
_connection = await _connectionFactory.CreateConnectionAsync(_cancellationTokenSource.Token);
_connection.ConnectionShutdown += (sender, ea) =>
@@ -119,8 +118,9 @@ public async Task DisposeAsync()
finally
{
_doneEvent.Dispose();
- _producerCredentialsProvider.Dispose();
+ _producerCredentialsProvider?.Dispose();
_connection?.Dispose();
+ _cancellationTokenSource.Dispose();
}
}
@@ -174,6 +174,7 @@ public async Task IntegrationTest()
async Task CloseConnection()
{
Assert.NotNull(_connection);
+ Assert.NotNull(_httpApiCredentialsProvider);
Credentials httpApiCredentials = await _httpApiCredentialsProvider.GetCredentialsAsync();
closeConnectionUtil = new Util(_testOutputHelper, "mgt_api_client", httpApiCredentials.Password);
await closeConnectionUtil.CloseConnectionAsync(_connection.ClientProvidedName);
@@ -217,6 +218,7 @@ async Task CloseConnection()
[Fact]
public async Task SecondConnectionCrashes_GH1429()
{
+ Assert.NotNull(_connectionFactory);
// https://github.com/rabbitmq/rabbitmq-dotnet-client/issues/1429
IConnection secondConnection = await _connectionFactory.CreateConnectionAsync(CancellationToken.None);
secondConnection.Dispose();
@@ -267,7 +269,7 @@ private async Task ConsumeAsync(IChannel consumeChannel)
await consumeChannel.BasicCancelAsync(consumerTag);
}
- private OAuth2ClientCredentialsProvider GetCredentialsProvider(OAuth2OptionsBase opts)
+ private async Task GetCredentialsProviderAsync(OAuth2OptionsBase opts)
{
_testOutputHelper.WriteLine("OAuth2Client ");
_testOutputHelper.WriteLine($"- ClientId: {opts.ClientId}");
@@ -276,7 +278,8 @@ private OAuth2ClientCredentialsProvider GetCredentialsProvider(OAuth2OptionsBase
_testOutputHelper.WriteLine($"- Scope: {opts.Scope}");
var tokenEndpointUri = new Uri(opts.TokenEndpoint);
- IOAuth2Client oAuth2Client = new OAuth2ClientBuilder(opts.ClientId, opts.ClientSecret, tokenEndpointUri).Build();
+ var builder = new OAuth2ClientBuilder(opts.ClientId, opts.ClientSecret, tokenEndpointUri);
+ IOAuth2Client oAuth2Client = await builder.BuildAsync();
return new OAuth2ClientCredentialsProvider(opts.Name, oAuth2Client);
}
diff --git a/projects/Test/OAuth2/TestOAuth2Client.cs b/projects/Test/OAuth2/TestOAuth2Client.cs
index 6b6e50ced..6f02b6136 100644
--- a/projects/Test/OAuth2/TestOAuth2Client.cs
+++ b/projects/Test/OAuth2/TestOAuth2Client.cs
@@ -41,24 +41,36 @@
namespace OAuth2Test
{
- public class TestOAuth2Client
+ public class TestOAuth2Client : IAsyncLifetime
{
protected string _client_id = "producer";
protected string _client_secret = "kbOFBXI9tANgKUq8vXHLhT6YhbivgXxn";
protected WireMockServer _oauthServer;
- protected IOAuth2Client _client;
+ protected IOAuth2Client? _client;
public TestOAuth2Client()
{
_oauthServer = WireMockServer.Start();
+ }
+
+ public async Task InitializeAsync()
+ {
var uri = new Uri(_oauthServer.Url + "/token");
- _client = new OAuth2ClientBuilder(_client_id, _client_secret, uri).Build();
+ var builder = new OAuth2ClientBuilder(_client_id, _client_secret, uri);
+ _client = await builder.BuildAsync();
+ }
+
+ public Task DisposeAsync()
+ {
+ return Task.CompletedTask;
}
[Fact]
public async Task TestRequestToken()
{
+ Assert.NotNull(_client);
+
JsonToken expectedJsonToken = new JsonToken("the_access_token", "the_refresh_token", TimeSpan.FromSeconds(10));
ExpectTokenRequest(new RequestFormMatcher()
.WithParam("client_id", _client_id)
@@ -76,6 +88,8 @@ public async Task TestRequestToken()
[Fact]
public async Task TestRefreshToken()
{
+ Assert.NotNull(_client);
+
const string accessToken0 = "the_access_token";
const string accessToken1 = "the_access_token_2";
const string refreshToken = "the_refresh_token";
@@ -110,6 +124,8 @@ public async Task TestRefreshToken()
[Fact]
public async Task TestInvalidCredentials()
{
+ Assert.NotNull(_client);
+
_oauthServer
.Given(
Request.Create()