diff --git a/projects/RabbitMQ.Client/client/api/ICredentialsRefresher.cs b/projects/RabbitMQ.Client/client/api/ICredentialsRefresher.cs index ea16d26b1f..ab68e0af1c 100644 --- a/projects/RabbitMQ.Client/client/api/ICredentialsRefresher.cs +++ b/projects/RabbitMQ.Client/client/api/ICredentialsRefresher.cs @@ -30,10 +30,12 @@ //--------------------------------------------------------------------------- using System; -using System.Collections.Concurrent; +using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.Diagnostics.Tracing; using System.Threading.Tasks; +using System.Timers; + namespace RabbitMQ.Client { public interface ICredentialsRefresher @@ -71,7 +73,9 @@ public class TimerBasedCredentialRefresherEventSource : EventSource public class TimerBasedCredentialRefresher : ICredentialsRefresher { - private readonly ConcurrentDictionary _registrations = new ConcurrentDictionary(); + private readonly IDictionary _registrations = + new Dictionary(); + private readonly object _lockObj = new(); public ICredentialsProvider Register(ICredentialsProvider provider, ICredentialsRefresher.NotifyCredentialRefreshedAsync callback) { @@ -80,65 +84,117 @@ public ICredentialsProvider Register(ICredentialsProvider provider, ICredentials return provider; } - if (_registrations.TryAdd(provider, scheduleTimer(provider, callback))) + lock (_lockObj) { + if (_registrations.TryGetValue(provider, out var registration)) + { + registration.Callback = callback; + TimerBasedCredentialRefresherEventSource.Log.AlreadyRegistered(provider.Name); + return provider; + } + + registration = new TimerRegistration(callback); + _registrations.Add(provider, registration); + registration.ScheduleTimer(provider); + TimerBasedCredentialRefresherEventSource.Log.Registered(provider.Name); } - else - { - TimerBasedCredentialRefresherEventSource.Log.AlreadyRegistered(provider.Name); - } return provider; } public bool Unregister(ICredentialsProvider provider) { - if (_registrations.TryRemove(provider, out System.Timers.Timer? timer)) + lock (_lockObj) { - try + if (_registrations.Remove(provider, out var registration)) { TimerBasedCredentialRefresherEventSource.Log.Unregistered(provider.Name); - timer.Stop(); + registration.Dispose(); + return true; } - finally - { - timer.Dispose(); - } - return true; - } - else - { - return false; } + + return false; } - private System.Timers.Timer scheduleTimer(ICredentialsProvider provider, ICredentialsRefresher.NotifyCredentialRefreshedAsync callback) + private class TimerRegistration : IDisposable { - System.Timers.Timer timer = new System.Timers.Timer(); - timer.Interval = provider.ValidUntil!.Value.TotalMilliseconds * (1.0 - (1 / 3.0)); - timer.Elapsed += (o, e) => + + private System.Timers.Timer? _timer; + private bool _disposed; + + public ICredentialsRefresher.NotifyCredentialRefreshedAsync Callback { get; set; } + + public TimerRegistration(ICredentialsRefresher.NotifyCredentialRefreshedAsync callback) { - TimerBasedCredentialRefresherEventSource.Log.TriggeredTimer(provider.Name); + Callback = callback; + } + + public void ScheduleTimer(ICredentialsProvider provider) + { + if (provider.ValidUntil == null) + { + throw new ArgumentNullException(nameof(provider.ValidUntil) + " of " + provider.GetType().Name + " was null"); + } + if (_disposed) + { + return; + } + + var newTimer = new Timer(); + newTimer.Interval = provider.ValidUntil.Value.TotalMilliseconds * (1.0 - 1 / 3.0); + newTimer.Elapsed += async (o, e) => + { + TimerBasedCredentialRefresherEventSource.Log.TriggeredTimer(provider.Name); + if (_disposed) + { + // We were waiting and the registration has been disposed in meanwhile + return; + } + + try + { + provider.Refresh(); + ScheduleTimer(provider); + await Callback.Invoke(provider.Password != null).ConfigureAwait(false); + TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, true); + } + catch (Exception) + { + await Callback.Invoke(false).ConfigureAwait(false); + TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, false); + } + }; + newTimer.Enabled = true; + newTimer.AutoReset = false; + TimerBasedCredentialRefresherEventSource.Log.ScheduledTimer(provider.Name, newTimer.Interval); + var oldTimer = _timer; + _timer = newTimer; + oldTimer?.Dispose(); + } + + public void Dispose() + { + if (_disposed) + { + throw new ObjectDisposedException(GetType().FullName); + } + try { - provider.Refresh(); - scheduleTimer(provider, callback); - callback.Invoke(provider.Password != null); - TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, true); + _timer?.Stop(); + _disposed = true; } - catch (Exception) + finally { - callback.Invoke(false); - TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, false); + _timer?.Dispose(); + _timer = null; } + } - }; - timer.Enabled = true; - timer.AutoReset = false; - TimerBasedCredentialRefresherEventSource.Log.ScheduledTimer(provider.Name, timer.Interval); - return timer; } + } class NoOpCredentialsRefresher : ICredentialsRefresher diff --git a/projects/Test/Unit/TestTimerBasedCredentialRefresher.cs b/projects/Test/Unit/TestTimerBasedCredentialRefresher.cs index faed6f03f7..99ecbf03ca 100644 --- a/projects/Test/Unit/TestTimerBasedCredentialRefresher.cs +++ b/projects/Test/Unit/TestTimerBasedCredentialRefresher.cs @@ -43,7 +43,7 @@ public class MockCredentialsProvider : ICredentialsProvider private readonly ITestOutputHelper _testOutputHelper; private readonly TimeSpan? _validUntil = TimeSpan.FromSeconds(1); private Exception _ex = null; - private bool _refreshCalled = false; + private int _refreshCalledTimes = 0; public MockCredentialsProvider(ITestOutputHelper testOutputHelper) { @@ -56,11 +56,11 @@ public MockCredentialsProvider(ITestOutputHelper testOutputHelper, TimeSpan vali _validUntil = validUntil; } - public bool RefreshCalled + public int RefreshCalledTimes { get { - return _refreshCalled; + return _refreshCalledTimes; } } @@ -87,7 +87,7 @@ public string Password public void Refresh() { - _refreshCalled = true; + _refreshCalledTimes++; } public void PasswordThrows(Exception ex) @@ -145,7 +145,50 @@ Task cb(bool arg) _refresher.Register(credentialsProvider, cb); Assert.True(await tcs.Task); - Assert.True(credentialsProvider.RefreshCalled); + Assert.True(credentialsProvider.RefreshCalledTimes > 0); + Assert.True(_refresher.Unregister(credentialsProvider)); + } + } + } + + [Fact] + public async Task TestRefreshTokenUpdateCallback() + { + var tcs1 = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + var tcs2 = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + int cb1CalledTimes = 0; + int cb2CalledTimes = 0; + + using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5))) + { + using (CancellationTokenRegistration ctr = cts.Token.Register(() => { tcs1.TrySetCanceled(); tcs2.TrySetCanceled(); })) + { + var credentialsProvider = new MockCredentialsProvider(_testOutputHelper, TimeSpan.FromSeconds(1)); + + Task cb1(bool arg) + { + cb1CalledTimes++; + tcs1.SetResult(arg); + return Task.CompletedTask; + } + + Task cb2(bool arg) + { + cb2CalledTimes++; + tcs2.SetResult(arg); + return Task.CompletedTask; + } + + _refresher.Register(credentialsProvider, cb1); + Assert.True(await tcs1.Task); + Assert.True(credentialsProvider.RefreshCalledTimes == 1); + Assert.True(cb1CalledTimes == 1); + _refresher.Register(credentialsProvider, cb2); + Assert.True(await tcs2.Task); + Assert.True(credentialsProvider.RefreshCalledTimes == 2); + Assert.True(cb2CalledTimes == 1); + Assert.True(cb1CalledTimes == 1); + Assert.True(_refresher.Unregister(credentialsProvider)); } } @@ -172,7 +215,7 @@ Task cb(bool arg) _refresher.Register(credentialsProvider, cb); Assert.False(await tcs.Task); - Assert.True(credentialsProvider.RefreshCalled); + Assert.True(credentialsProvider.RefreshCalledTimes > 0); Assert.True(_refresher.Unregister(credentialsProvider)); } }