Skip to content
128 changes: 92 additions & 36 deletions projects/RabbitMQ.Client/client/api/ICredentialsRefresher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
//---------------------------------------------------------------------------

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Diagnostics.Tracing;
using System.Threading.Tasks;
using System.Timers;

namespace RabbitMQ.Client
{
public interface ICredentialsRefresher
Expand Down Expand Up @@ -71,7 +73,9 @@ public class TimerBasedCredentialRefresherEventSource : EventSource

public class TimerBasedCredentialRefresher : ICredentialsRefresher
{
private readonly ConcurrentDictionary<ICredentialsProvider, System.Timers.Timer> _registrations = new ConcurrentDictionary<ICredentialsProvider, System.Timers.Timer>();
private readonly IDictionary<ICredentialsProvider, TimerRegistration> _registrations =
new Dictionary<ICredentialsProvider, TimerRegistration>();
private readonly object _lockObj = new();

public ICredentialsProvider Register(ICredentialsProvider provider, ICredentialsRefresher.NotifyCredentialRefreshedAsync callback)
{
Expand All @@ -80,65 +84,117 @@ public ICredentialsProvider Register(ICredentialsProvider provider, ICredentials
return provider;
}

if (_registrations.TryAdd(provider, scheduleTimer(provider, callback)))
lock (_lockObj)
{
if (_registrations.TryGetValue(provider, out var registration))
{
registration.Callback = callback;
TimerBasedCredentialRefresherEventSource.Log.AlreadyRegistered(provider.Name);
return provider;
}

registration = new TimerRegistration(callback);
_registrations.Add(provider, registration);
registration.ScheduleTimer(provider);

TimerBasedCredentialRefresherEventSource.Log.Registered(provider.Name);
}
else
{
TimerBasedCredentialRefresherEventSource.Log.AlreadyRegistered(provider.Name);
}

return provider;
}

public bool Unregister(ICredentialsProvider provider)
{
if (_registrations.TryRemove(provider, out System.Timers.Timer? timer))
lock (_lockObj)
{
try
if (_registrations.Remove(provider, out var registration))
{
TimerBasedCredentialRefresherEventSource.Log.Unregistered(provider.Name);
timer.Stop();
registration.Dispose();
return true;
}
finally
{
timer.Dispose();
}
return true;
}
else
{
return false;
}

return false;
}

private System.Timers.Timer scheduleTimer(ICredentialsProvider provider, ICredentialsRefresher.NotifyCredentialRefreshedAsync callback)
private class TimerRegistration : IDisposable
{
System.Timers.Timer timer = new System.Timers.Timer();
timer.Interval = provider.ValidUntil!.Value.TotalMilliseconds * (1.0 - (1 / 3.0));
timer.Elapsed += (o, e) =>

private System.Timers.Timer? _timer;
private bool _disposed;

public ICredentialsRefresher.NotifyCredentialRefreshedAsync Callback { get; set; }

public TimerRegistration(ICredentialsRefresher.NotifyCredentialRefreshedAsync callback)
{
TimerBasedCredentialRefresherEventSource.Log.TriggeredTimer(provider.Name);
Callback = callback;
}

public void ScheduleTimer(ICredentialsProvider provider)
{
if (provider.ValidUntil == null)
{
throw new ArgumentNullException(nameof(provider.ValidUntil) + " of " + provider.GetType().Name + " was null");
}
if (_disposed)
{
return;
}

var newTimer = new Timer();
newTimer.Interval = provider.ValidUntil.Value.TotalMilliseconds * (1.0 - 1 / 3.0);
newTimer.Elapsed += async (o, e) =>
{
TimerBasedCredentialRefresherEventSource.Log.TriggeredTimer(provider.Name);
if (_disposed)
{
// We were waiting and the registration has been disposed in meanwhile
return;
}

try
{
provider.Refresh();
ScheduleTimer(provider);
await Callback.Invoke(provider.Password != null).ConfigureAwait(false);
TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, true);
}
catch (Exception)
{
await Callback.Invoke(false).ConfigureAwait(false);
TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, false);
}
};
newTimer.Enabled = true;
newTimer.AutoReset = false;
TimerBasedCredentialRefresherEventSource.Log.ScheduledTimer(provider.Name, newTimer.Interval);
var oldTimer = _timer;
_timer = newTimer;
oldTimer?.Dispose();
}

public void Dispose()
{
if (_disposed)
{
throw new ObjectDisposedException(GetType().FullName);
}

try
{
provider.Refresh();
scheduleTimer(provider, callback);
callback.Invoke(provider.Password != null);
TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, true);
_timer?.Stop();
_disposed = true;
}
catch (Exception)
finally
{
callback.Invoke(false);
TimerBasedCredentialRefresherEventSource.Log.RefreshedCredentials(provider.Name, false);
_timer?.Dispose();
_timer = null;
}
}

};
timer.Enabled = true;
timer.AutoReset = false;
TimerBasedCredentialRefresherEventSource.Log.ScheduledTimer(provider.Name, timer.Interval);
return timer;
}

}

class NoOpCredentialsRefresher : ICredentialsRefresher
Expand Down
55 changes: 49 additions & 6 deletions projects/Test/Unit/TestTimerBasedCredentialRefresher.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class MockCredentialsProvider : ICredentialsProvider
private readonly ITestOutputHelper _testOutputHelper;
private readonly TimeSpan? _validUntil = TimeSpan.FromSeconds(1);
private Exception _ex = null;
private bool _refreshCalled = false;
private int _refreshCalledTimes = 0;

public MockCredentialsProvider(ITestOutputHelper testOutputHelper)
{
Expand All @@ -56,11 +56,11 @@ public MockCredentialsProvider(ITestOutputHelper testOutputHelper, TimeSpan vali
_validUntil = validUntil;
}

public bool RefreshCalled
public int RefreshCalledTimes
{
get
{
return _refreshCalled;
return _refreshCalledTimes;
}
}

Expand All @@ -87,7 +87,7 @@ public string Password

public void Refresh()
{
_refreshCalled = true;
_refreshCalledTimes++;
}

public void PasswordThrows(Exception ex)
Expand Down Expand Up @@ -145,7 +145,50 @@ Task cb(bool arg)

_refresher.Register(credentialsProvider, cb);
Assert.True(await tcs.Task);
Assert.True(credentialsProvider.RefreshCalled);
Assert.True(credentialsProvider.RefreshCalledTimes > 0);
Assert.True(_refresher.Unregister(credentialsProvider));
}
}
}

[Fact]
public async Task TestRefreshTokenUpdateCallback()
{
var tcs1 = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
var tcs2 = new TaskCompletionSource<bool>(TaskCreationOptions.RunContinuationsAsynchronously);
int cb1CalledTimes = 0;
int cb2CalledTimes = 0;

using (var cts = new CancellationTokenSource(TimeSpan.FromSeconds(5)))
{
using (CancellationTokenRegistration ctr = cts.Token.Register(() => { tcs1.TrySetCanceled(); tcs2.TrySetCanceled(); }))
{
var credentialsProvider = new MockCredentialsProvider(_testOutputHelper, TimeSpan.FromSeconds(1));

Task cb1(bool arg)
{
cb1CalledTimes++;
tcs1.SetResult(arg);
return Task.CompletedTask;
}

Task cb2(bool arg)
{
cb2CalledTimes++;
tcs2.SetResult(arg);
return Task.CompletedTask;
}

_refresher.Register(credentialsProvider, cb1);
Assert.True(await tcs1.Task);
Assert.True(credentialsProvider.RefreshCalledTimes == 1);
Assert.True(cb1CalledTimes == 1);
_refresher.Register(credentialsProvider, cb2);
Assert.True(await tcs2.Task);
Assert.True(credentialsProvider.RefreshCalledTimes == 2);
Assert.True(cb2CalledTimes == 1);
Assert.True(cb1CalledTimes == 1);

Assert.True(_refresher.Unregister(credentialsProvider));
}
}
Expand All @@ -172,7 +215,7 @@ Task cb(bool arg)

_refresher.Register(credentialsProvider, cb);
Assert.False(await tcs.Task);
Assert.True(credentialsProvider.RefreshCalled);
Assert.True(credentialsProvider.RefreshCalledTimes > 0);
Assert.True(_refresher.Unregister(credentialsProvider));
}
}
Expand Down