Skip to content

Commit 18859f5

Browse files
committed
Auth Tests
Signed-off-by: Gabriele Santomaggio <[email protected]>
1 parent 236492f commit 18859f5

File tree

7 files changed

+217
-70
lines changed

7 files changed

+217
-70
lines changed

RabbitMQ.AMQP.Client/ConnectionSettings.cs

Lines changed: 43 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -154,11 +154,10 @@ public ConnectionSettingsBuilder UriSelector(IUriSelector uriSelector)
154154
return this;
155155
}
156156

157-
public ConnectionSettingsBuilder OAuth2Options(OAuth2Options? oAuth2Options = null)
157+
public ConnectionSettingsBuilder OAuth2Options(OAuth2Options? oAuth2Options)
158158
{
159159
_oAuth2Options = oAuth2Options;
160160
return this;
161-
162161
}
163162

164163
public ConnectionSettings Build()
@@ -171,26 +170,26 @@ public ConnectionSettings Build()
171170
_containerId, _saslMechanism,
172171
_recoveryConfiguration,
173172
_maxFrameSize,
174-
_tlsSettings);
173+
_tlsSettings,
174+
_oAuth2Options);
175175
}
176-
else if (_uris is not null)
176+
177+
if (_uris is not null)
177178
{
178179
return new ClusterConnectionSettings(_uris,
179180
_uriSelector,
180181
_containerId, _saslMechanism,
181182
_recoveryConfiguration,
182183
_maxFrameSize,
183-
_tlsSettings);
184-
}
185-
else
186-
{
187-
return new ConnectionSettings(_scheme, _host, _port, _user,
188-
_password, _virtualHost,
189-
_containerId, _saslMechanism,
190-
_recoveryConfiguration,
191-
_maxFrameSize,
192-
_tlsSettings);
184+
_tlsSettings, _oAuth2Options);
193185
}
186+
187+
return new ConnectionSettings(_scheme, _host, _port, _user,
188+
_password, _virtualHost,
189+
_containerId, _saslMechanism,
190+
_recoveryConfiguration,
191+
_maxFrameSize,
192+
_tlsSettings, _oAuth2Options);
194193
}
195194

196195
private void ValidateUris()
@@ -214,14 +213,14 @@ public class ConnectionSettings : IEquatable<ConnectionSettings>
214213
private readonly TlsSettings? _tlsSettings;
215214
private readonly SaslMechanism _saslMechanism = SaslMechanism.Plain;
216215
private readonly IRecoveryConfiguration _recoveryConfiguration = new RecoveryConfiguration();
217-
private readonly OAuth2Options? _oAuth2Options = null;
218216

219217
public ConnectionSettings(Uri uri,
220218
string? containerId = null,
221219
SaslMechanism? saslMechanism = null,
222220
IRecoveryConfiguration? recoveryConfiguration = null,
223221
uint? maxFrameSize = null,
224-
TlsSettings? tlsSettings = null)
222+
TlsSettings? tlsSettings = null,
223+
OAuth2Options? oAuth2Options = null)
225224
: this(containerId, saslMechanism, recoveryConfiguration, maxFrameSize, tlsSettings)
226225
{
227226
(string? user, string? password) = ProcessUserInfo(uri);
@@ -241,6 +240,14 @@ public ConnectionSettings(Uri uri,
241240
path: "/",
242241
scheme: scheme);
243242

243+
if (oAuth2Options is not null)
244+
{
245+
// in case of OAuth2, we need to use plain mechanism
246+
_saslMechanism = SaslMechanism.Plain;
247+
_address = new Address(_address.Host, _address.Port, "", oAuth2Options.Token, _address.Path,
248+
_address.Scheme);
249+
}
250+
244251
_tlsSettings = InitTlsSettings();
245252
}
246253

@@ -254,7 +261,8 @@ public ConnectionSettings(string scheme,
254261
SaslMechanism? saslMechanism = null,
255262
IRecoveryConfiguration? recoveryConfiguration = null,
256263
uint? maxFrameSize = null,
257-
TlsSettings? tlsSettings = null)
264+
TlsSettings? tlsSettings = null,
265+
OAuth2Options? oAuth2Options = null)
258266
: this(containerId, saslMechanism, recoveryConfiguration, maxFrameSize, tlsSettings)
259267
{
260268
if (false == Utils.IsValidScheme(scheme))
@@ -274,6 +282,14 @@ public ConnectionSettings(string scheme,
274282
_virtualHost = virtualHost;
275283
}
276284

285+
if (oAuth2Options is not null)
286+
{
287+
// in case of OAuth2, we need to use plain mechanism
288+
_saslMechanism = SaslMechanism.Plain;
289+
_address = new Address(_address.Host, _address.Port, "", oAuth2Options.Token, _address.Path,
290+
_address.Scheme);
291+
}
292+
277293
_tlsSettings = InitTlsSettings();
278294
}
279295

@@ -282,8 +298,7 @@ protected ConnectionSettings(
282298
SaslMechanism? saslMechanism = null,
283299
IRecoveryConfiguration? recoveryConfiguration = null,
284300
uint? maxFrameSize = null,
285-
TlsSettings? tlsSettings = null,
286-
OAuth2Options? oAuth2Options = null)
301+
TlsSettings? tlsSettings = null)
287302
{
288303
if (containerId is not null)
289304
{
@@ -300,15 +315,6 @@ protected ConnectionSettings(
300315
_recoveryConfiguration = recoveryConfiguration;
301316
}
302317

303-
_oAuth2Options = oAuth2Options;
304-
if (_oAuth2Options is not null)
305-
{
306-
// in case of OAuth2, we need to use plain mechanism
307-
_saslMechanism = SaslMechanism.Plain;
308-
_address = new Address(_address.Host, _address.Port, _address.User, _oAuth2Options.Token, _address.Path,
309-
_address.Scheme);
310-
}
311-
312318
if (maxFrameSize is not null)
313319
{
314320
_maxFrameSize = (uint)maxFrameSize;
@@ -477,7 +483,7 @@ public ClusterConnectionSettings(IEnumerable<Uri> uris,
477483
uint? maxFrameSize = null,
478484
TlsSettings? tlsSettings = null,
479485
OAuth2Options? oAuth2Options = null)
480-
: base(containerId, saslMechanism, recoveryConfiguration, maxFrameSize, tlsSettings, oAuth2Options)
486+
: base(containerId, saslMechanism, recoveryConfiguration, maxFrameSize, tlsSettings)
481487
{
482488
_uris = uris.ToList();
483489
if (_uris.Count == 0)
@@ -526,6 +532,14 @@ public ClusterConnectionSettings(IEnumerable<Uri> uris,
526532
path: "/",
527533
scheme: scheme);
528534

535+
// if (oAuth2Options is not null)
536+
// {
537+
// // in case of OAuth2, we need to use plain mechanism
538+
// _saslMechanism = SaslMechanism.Plain;
539+
// _address = new Address(_address.Host, _address.Port, "", oAuth2Options.Token, _address.Path,
540+
// _address.Scheme);
541+
// }
542+
529543
_uriToAddress[uri] = address;
530544

531545
if (first)

RabbitMQ.AMQP.Client/PublicAPI.Unshipped.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ RabbitMQ.AMQP.Client.ConnectionException
6464
RabbitMQ.AMQP.Client.ConnectionException.ConnectionException(string! message) -> void
6565
RabbitMQ.AMQP.Client.ConnectionException.ConnectionException(string! message, System.Exception! innerException) -> void
6666
RabbitMQ.AMQP.Client.ConnectionSettings
67-
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string! scheme, string! host, int port, string? user = null, string? password = null, string? virtualHost = null, string! containerId = "", RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
68-
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null, RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> void
69-
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(System.Uri! uri, string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
67+
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string! scheme, string! host, int port, string? user = null, string? password = null, string? virtualHost = null, string! containerId = "", RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null, RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> void
68+
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
69+
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(System.Uri! uri, string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null, RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> void
7070
RabbitMQ.AMQP.Client.ConnectionSettings.ContainerId.get -> string!
7171
RabbitMQ.AMQP.Client.ConnectionSettings.Host.get -> string!
7272
RabbitMQ.AMQP.Client.ConnectionSettings.MaxFrameSize.get -> uint
@@ -88,7 +88,7 @@ RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.ConnectionSettingsBuilder() -> vo
8888
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.ContainerId(string! containerId) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
8989
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.Host(string! host) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9090
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.MaxFrameSize(uint maxFrameSize) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
91-
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.OAuth2Options(RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
91+
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.OAuth2Options(RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9292
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.Password(string! password) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9393
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.Port(int port) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9494
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.RecoveryConfiguration(RabbitMQ.AMQP.Client.IRecoveryConfiguration! recoveryConfiguration) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!

Tests/OAuth2Tests.cs

Lines changed: 62 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55
using System;
66
using System.Collections.Generic;
77
using System.IdentityModel.Tokens.Jwt;
8+
using System.Linq;
89
using System.Security.Claims;
910
using System.Security.Cryptography;
1011
using System.Text;
12+
using System.Threading;
1113
using System.Threading.Tasks;
14+
using Amqp;
1215
using Microsoft.IdentityModel.Tokens;
1316
using RabbitMQ.AMQP.Client;
1417
using RabbitMQ.AMQP.Client.Impl;
1518
using Xunit;
19+
using IConnection = RabbitMQ.AMQP.Client.IConnection;
1620

1721
namespace Tests
1822
{
@@ -23,7 +27,7 @@ public class OAuth2Tests
2327
private const string Audience = "rabbitmq";
2428

2529
[Fact]
26-
public async Task ConnectToRabbitMqWithOAuth2Token()
30+
public async Task ConnectToRabbitMqWithOAuth2TokenShouldSuccess()
2731
{
2832
IConnection connection = await AmqpConnection.CreateAsync(
2933
ConnectionSettingsBuilder.Create()
@@ -36,55 +40,76 @@ public async Task ConnectToRabbitMqWithOAuth2Token()
3640
await connection.CloseAsync();
3741
}
3842

39-
private static string GenerateToken(DateTime expiration)
43+
[Fact]
44+
public async Task ConnectToRabbitMqWithOAuth2TokenShouldDisconnectAfterTimeout()
4045
{
41-
byte[] decodedKey = Convert.FromBase64String(Base64Key);
46+
IConnection connection = await AmqpConnection.CreateAsync(
47+
ConnectionSettingsBuilder.Create()
48+
.Host("localhost")
49+
.Port(5672)
50+
.RecoveryConfiguration(new RecoveryConfiguration().Activated(false).Topology(false))
51+
.OAuth2Options(new OAuth2Options(GenerateToken(DateTime.UtcNow.AddMilliseconds(1_000))))
52+
.Build());
4253

43-
var claims = new List<Claim>
54+
Assert.NotNull(connection);
55+
Assert.Equal(State.Open, connection.State);
56+
State? stateFrom = null;
57+
State? stateTo = null;
58+
Error? stateError = null;
59+
TaskCompletionSource<bool> tcs = new();
60+
connection.ChangeState += (sender, from, to, error) =>
4461
{
45-
new (JwtRegisteredClaimNames.Iss, "unit_test"),
46-
new (JwtRegisteredClaimNames.Aud, Audience),
47-
new (JwtRegisteredClaimNames.Exp,
48-
new DateTimeOffset(expiration).ToUniversalTime().ToUnixTimeSeconds().ToString()),
49-
new ("scope", "rabbitmq.configure:*/*"),
50-
new ("scope", "rabbitmq.write:*/*"),
51-
new ("scope", "rabbitmq.read:*/*"),
52-
new ("random", RandomString(6))
62+
stateFrom = from;
63+
stateTo = to;
64+
stateError = error;
65+
tcs.SetResult(true);
5366
};
5467

55-
var tokenHandler = new JwtSecurityTokenHandler();
56-
var claimIdentity = new ClaimsIdentity(claims);
57-
var tokenDescriptor = new SecurityTokenDescriptor
68+
await tcs.Task;
69+
Assert.NotNull(stateFrom);
70+
Assert.NotNull(stateTo);
71+
Assert.NotNull(stateError);
72+
Assert.NotNull(stateError.ErrorCode);
73+
Assert.Equal(State.Open, stateFrom);
74+
Assert.Equal(State.Closed, stateTo);
75+
Assert.Equal(State.Closed, connection.State);
76+
Assert.Contains(stateError.ErrorCode, "amqp:unauthorized-access");
77+
}
78+
79+
private static string GenerateToken(DateTime duration)
80+
{
81+
byte[] decodedKey = Convert.FromBase64String(Base64Key);
82+
83+
var claims = new[]
5884
{
59-
Subject = claimIdentity,
60-
Expires = expiration,
61-
SigningCredentials =
62-
new SigningCredentials(new SymmetricSecurityKey(decodedKey),
63-
SecurityAlgorithms.HmacSha256Signature),
64-
Claims = new Dictionary<string, object> { ["kid"] = "token-key" }
85+
new Claim(JwtRegisteredClaimNames.Iss, "unit_test"),
86+
new Claim(JwtRegisteredClaimNames.Aud, Audience),
87+
new Claim(JwtRegisteredClaimNames.Exp, new DateTimeOffset(duration).ToUnixTimeSeconds().ToString()),
88+
new Claim("scope", "rabbitmq.configure:*/* rabbitmq.write:*/* rabbitmq.read:*/*"),
89+
new Claim("random", GenerateRandomString(6))
6590
};
6691

67-
var token = tokenHandler.CreateToken(tokenDescriptor);
92+
var key = new SymmetricSecurityKey(decodedKey);
93+
var creds = new SigningCredentials(key, SecurityAlgorithms.HmacSha256);
94+
95+
var token = new JwtSecurityToken(
96+
claims: claims,
97+
expires: duration,
98+
signingCredentials: creds
99+
);
100+
101+
token.Header["kid"] = "token-key";
102+
103+
var tokenHandler = new JwtSecurityTokenHandler();
68104
return tokenHandler.WriteToken(token);
69105
}
70106

71-
private static string RandomString(int length)
107+
private static string GenerateRandomString(int length)
72108
{
73109
const string chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
74-
var result = new StringBuilder(length);
75-
76-
using (var rng = RandomNumberGenerator.Create())
77-
{
78-
byte[] byteArray = new byte[length];
79-
rng.GetBytes(byteArray);
80-
81-
foreach (byte byteValue in byteArray)
82-
{
83-
result.Append(chars[byteValue % chars.Length]);
84-
}
85-
}
86-
87-
return result.ToString();
110+
var random = new Random();
111+
return new string(Enumerable.Repeat(chars, length)
112+
.Select(s => s[random.Next(s.Length)]).ToArray());
88113
}
89114
}
90115
}

docs/Examples/OAuth2/OAuth2.csproj

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
<Project Sdk="Microsoft.NET.Sdk">
2+
3+
<PropertyGroup>
4+
<OutputType>Exe</OutputType>
5+
<TargetFramework>net8.0</TargetFramework>
6+
<ImplicitUsings>enable</ImplicitUsings>
7+
<Nullable>enable</Nullable>
8+
</PropertyGroup>
9+
10+
<ItemGroup>
11+
<ProjectReference Include="..\..\..\RabbitMQ.AMQP.Client\RabbitMQ.AMQP.Client.csproj" />
12+
</ItemGroup>
13+
14+
<ItemGroup>
15+
<PackageReference Include="System.IdentityModel.Tokens.Jwt" />
16+
</ItemGroup>
17+
18+
</Project>

docs/Examples/OAuth2/Program.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// See https://aka.ms/new-console-template for more information
2+
3+
using System.Diagnostics;
4+
using OAuth2;
5+
using RabbitMQ.AMQP.Client;
6+
using RabbitMQ.AMQP.Client.Impl;
7+
using Trace = Amqp.Trace;
8+
using TraceLevel = Amqp.TraceLevel;
9+
10+
Trace.TraceLevel = TraceLevel.Verbose;
11+
12+
ConsoleTraceListener consoleListener = new();
13+
Trace.TraceListener = (l, f, a) =>
14+
consoleListener.WriteLine($"[{DateTime.Now}] [{l}] - {f}");
15+
16+
IConnection connection = await AmqpConnection.CreateAsync(
17+
ConnectionSettingsBuilder.Create()
18+
.Host("localhost")
19+
.Port(5672)
20+
.OAuth2Options(new OAuth2Options(Token.GenerateToken(DateTime.UtcNow.AddMilliseconds(1500))))
21+
.Build()).ConfigureAwait(false);
22+
23+
Trace.WriteLine(TraceLevel.Information, $"Connected to the broker {connection} successfully");
24+
Trace.WriteLine(TraceLevel.Information, $"Connection status {connection.State}");
25+
26+
Thread.Sleep(TimeSpan.FromSeconds(15));
27+
28+
Console.WriteLine("Connection state: " + connection.State);

0 commit comments

Comments
 (0)