Skip to content

Commit 236492f

Browse files
committed
auth configuration
Signed-off-by: Gabriele Santomaggio <[email protected]>
1 parent 4ec6264 commit 236492f

File tree

5 files changed

+140
-9
lines changed

5 files changed

+140
-9
lines changed

Directory.Packages.props

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
<!-- Tests -->
1414
<PackageVersion Include="Microsoft.Extensions.Diagnostics" Version="9.0.0" />
1515
<PackageVersion Include="Microsoft.Extensions.Diagnostics.Testing" Version="9.0.0" />
16+
<PackageVersion Include="System.IdentityModel.Tokens.Jwt" Version="8.6.1" />
1617
<PackageVersion Include="System.Text.Json" Version="9.0.0" />
1718
<PackageVersion Include="xunit" Version="2.9.2" />
1819
<PackageVersion Include="xunit.runner.visualstudio" Version="2.8.2" />
@@ -39,4 +40,4 @@
3940
<GlobalPackageReference Include="Microsoft.SourceLink.GitHub" Version="8.0.0" />
4041
<GlobalPackageReference Include="MinVer" Version="6.0.0" />
4142
</ItemGroup>
42-
</Project>
43+
</Project>

RabbitMQ.AMQP.Client/ConnectionSettings.cs

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ public class ConnectionSettingsBuilder
4141
private Uri? _uri;
4242
private List<Uri>? _uris;
4343
private IUriSelector? _uriSelector;
44+
private OAuth2Options? _oAuth2Options;
4445

4546
public static ConnectionSettingsBuilder Create()
4647
{
@@ -104,6 +105,7 @@ public ConnectionSettingsBuilder MaxFrameSize(uint maxFrameSize)
104105
throw new ArgumentOutOfRangeException(nameof(maxFrameSize),
105106
"maxFrameSize must be 0 (no limit) or greater than or equal to 512");
106107
}
108+
107109
return this;
108110
}
109111

@@ -152,6 +154,13 @@ public ConnectionSettingsBuilder UriSelector(IUriSelector uriSelector)
152154
return this;
153155
}
154156

157+
public ConnectionSettingsBuilder OAuth2Options(OAuth2Options? oAuth2Options = null)
158+
{
159+
_oAuth2Options = oAuth2Options;
160+
return this;
161+
162+
}
163+
155164
public ConnectionSettings Build()
156165
{
157166
// TODO this should do something similar to consolidate in the Java code
@@ -205,6 +214,7 @@ public class ConnectionSettings : IEquatable<ConnectionSettings>
205214
private readonly TlsSettings? _tlsSettings;
206215
private readonly SaslMechanism _saslMechanism = SaslMechanism.Plain;
207216
private readonly IRecoveryConfiguration _recoveryConfiguration = new RecoveryConfiguration();
217+
private readonly OAuth2Options? _oAuth2Options = null;
208218

209219
public ConnectionSettings(Uri uri,
210220
string? containerId = null,
@@ -272,7 +282,8 @@ protected ConnectionSettings(
272282
SaslMechanism? saslMechanism = null,
273283
IRecoveryConfiguration? recoveryConfiguration = null,
274284
uint? maxFrameSize = null,
275-
TlsSettings? tlsSettings = null)
285+
TlsSettings? tlsSettings = null,
286+
OAuth2Options? oAuth2Options = null)
276287
{
277288
if (containerId is not null)
278289
{
@@ -289,6 +300,15 @@ protected ConnectionSettings(
289300
_recoveryConfiguration = recoveryConfiguration;
290301
}
291302

303+
_oAuth2Options = oAuth2Options;
304+
if (_oAuth2Options is not null)
305+
{
306+
// in case of OAuth2, we need to use plain mechanism
307+
_saslMechanism = SaslMechanism.Plain;
308+
_address = new Address(_address.Host, _address.Port, _address.User, _oAuth2Options.Token, _address.Path,
309+
_address.Scheme);
310+
}
311+
292312
if (maxFrameSize is not null)
293313
{
294314
_maxFrameSize = (uint)maxFrameSize;
@@ -408,7 +428,8 @@ protected static string ProcessUriSegmentsForVirtualHost(Uri uri)
408428
// that has at least the path segment "/"
409429
if (uri.Segments.Length > 2)
410430
{
411-
throw new ArgumentException($"Multiple segments in path of AMQP URI: {string.Join(", ", uri.Segments)}");
431+
throw new ArgumentException(
432+
$"Multiple segments in path of AMQP URI: {string.Join(", ", uri.Segments)}");
412433
}
413434

414435
if (uri.Segments.Length == 2)
@@ -454,16 +475,17 @@ public ClusterConnectionSettings(IEnumerable<Uri> uris,
454475
SaslMechanism? saslMechanism = null,
455476
IRecoveryConfiguration? recoveryConfiguration = null,
456477
uint? maxFrameSize = null,
457-
TlsSettings? tlsSettings = null)
458-
: base(containerId, saslMechanism, recoveryConfiguration, maxFrameSize, tlsSettings)
478+
TlsSettings? tlsSettings = null,
479+
OAuth2Options? oAuth2Options = null)
480+
: base(containerId, saslMechanism, recoveryConfiguration, maxFrameSize, tlsSettings, oAuth2Options)
459481
{
460482
_uris = uris.ToList();
461483
if (_uris.Count == 0)
462484
{
463485
throw new ArgumentOutOfRangeException(nameof(uris), "At least one Uri is required.");
464486
}
465487

466-
_uriToAddress = new(_uris.Count);
488+
_uriToAddress = new Dictionary<Uri, Address>(_uris.Count);
467489

468490
if (uriSelector is not null)
469491
{
@@ -492,7 +514,8 @@ public ClusterConnectionSettings(IEnumerable<Uri> uris,
492514
string thisVirtualHost = ProcessUriSegmentsForVirtualHost(uri);
493515
if (false == thisVirtualHost.Equals(tmpVirtualHost, StringComparison.InvariantCultureIgnoreCase))
494516
{
495-
throw new ArgumentException($"All AMQP URIs must use the same virtual host. Expected '{tmpVirtualHost}', got '{thisVirtualHost}'");
517+
throw new ArgumentException(
518+
$"All AMQP URIs must use the same virtual host. Expected '{tmpVirtualHost}', got '{thisVirtualHost}'");
496519
}
497520
}
498521

@@ -551,6 +574,7 @@ public override int GetHashCode()
551574
{
552575
hashCode ^= _uris[i].GetHashCode();
553576
}
577+
554578
return hashCode;
555579
}
556580

@@ -600,4 +624,14 @@ private bool trustEverythingCertValidationCallback(object sender, X509Certificat
600624
return (sslPolicyErrors & ~AcceptablePolicyErrors) == SslPolicyErrors.None;
601625
}
602626
}
627+
628+
public class OAuth2Options
629+
{
630+
public OAuth2Options(string token)
631+
{
632+
Token = token;
633+
}
634+
635+
public string Token { get; set; }
636+
}
603637
}

RabbitMQ.AMQP.Client/PublicAPI.Unshipped.txt

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ RabbitMQ.AMQP.Client.ClassicQueueVersion
5959
RabbitMQ.AMQP.Client.ClassicQueueVersion.V1 = 0 -> RabbitMQ.AMQP.Client.ClassicQueueVersion
6060
RabbitMQ.AMQP.Client.ClassicQueueVersion.V2 = 1 -> RabbitMQ.AMQP.Client.ClassicQueueVersion
6161
RabbitMQ.AMQP.Client.ClusterConnectionSettings
62-
RabbitMQ.AMQP.Client.ClusterConnectionSettings.ClusterConnectionSettings(System.Collections.Generic.IEnumerable<System.Uri!>! uris, RabbitMQ.AMQP.Client.IUriSelector? uriSelector = null, string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
62+
RabbitMQ.AMQP.Client.ClusterConnectionSettings.ClusterConnectionSettings(System.Collections.Generic.IEnumerable<System.Uri!>! uris, RabbitMQ.AMQP.Client.IUriSelector? uriSelector = null, string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null, RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> void
6363
RabbitMQ.AMQP.Client.ConnectionException
6464
RabbitMQ.AMQP.Client.ConnectionException.ConnectionException(string! message) -> void
6565
RabbitMQ.AMQP.Client.ConnectionException.ConnectionException(string! message, System.Exception! innerException) -> void
6666
RabbitMQ.AMQP.Client.ConnectionSettings
6767
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string! scheme, string! host, int port, string? user = null, string? password = null, string? virtualHost = null, string! containerId = "", RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
68-
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
68+
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null, RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> void
6969
RabbitMQ.AMQP.Client.ConnectionSettings.ConnectionSettings(System.Uri! uri, string? containerId = null, RabbitMQ.AMQP.Client.SaslMechanism? saslMechanism = null, RabbitMQ.AMQP.Client.IRecoveryConfiguration? recoveryConfiguration = null, uint? maxFrameSize = null, RabbitMQ.AMQP.Client.TlsSettings? tlsSettings = null) -> void
7070
RabbitMQ.AMQP.Client.ConnectionSettings.ContainerId.get -> string!
7171
RabbitMQ.AMQP.Client.ConnectionSettings.Host.get -> string!
@@ -88,6 +88,7 @@ RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.ConnectionSettingsBuilder() -> vo
8888
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.ContainerId(string! containerId) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
8989
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.Host(string! host) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9090
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.MaxFrameSize(uint maxFrameSize) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
91+
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.OAuth2Options(RabbitMQ.AMQP.Client.OAuth2Options? oAuth2Options = null) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9192
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.Password(string! password) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9293
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.Port(int port) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
9394
RabbitMQ.AMQP.Client.ConnectionSettingsBuilder.RecoveryConfiguration(RabbitMQ.AMQP.Client.IRecoveryConfiguration! recoveryConfiguration) -> RabbitMQ.AMQP.Client.ConnectionSettingsBuilder!
@@ -698,6 +699,10 @@ RabbitMQ.AMQP.Client.MetricsReporter.PublisherClosed() -> void
698699
RabbitMQ.AMQP.Client.MetricsReporter.PublisherOpened() -> void
699700
RabbitMQ.AMQP.Client.ModelException
700701
RabbitMQ.AMQP.Client.ModelException.ModelException(string! message) -> void
702+
RabbitMQ.AMQP.Client.OAuth2Options
703+
RabbitMQ.AMQP.Client.OAuth2Options.OAuth2Options(string! token) -> void
704+
RabbitMQ.AMQP.Client.OAuth2Options.Token.get -> string!
705+
RabbitMQ.AMQP.Client.OAuth2Options.Token.set -> void
701706
RabbitMQ.AMQP.Client.OutcomeState
702707
RabbitMQ.AMQP.Client.OutcomeState.Accepted = 0 -> RabbitMQ.AMQP.Client.OutcomeState
703708
RabbitMQ.AMQP.Client.OutcomeState.Rejected = 1 -> RabbitMQ.AMQP.Client.OutcomeState

Tests/OAuth2Tests.cs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// This source code is dual-licensed under the Apache License, version 2.0,
2+
// and the Mozilla Public License, version 2.0.
3+
// Copyright (c) 2017-2024 Broadcom. All Rights Reserved. The term "Broadcom" refers to Broadcom Inc. and/or its subsidiaries.
4+
5+
using System;
6+
using System.Collections.Generic;
7+
using System.IdentityModel.Tokens.Jwt;
8+
using System.Security.Claims;
9+
using System.Security.Cryptography;
10+
using System.Text;
11+
using System.Threading.Tasks;
12+
using Microsoft.IdentityModel.Tokens;
13+
using RabbitMQ.AMQP.Client;
14+
using RabbitMQ.AMQP.Client.Impl;
15+
using Xunit;
16+
17+
namespace Tests
18+
{
19+
public class OAuth2Tests
20+
{
21+
private const string Base64Key = "abcdefghijklmnopqrstuvwxyz0123456789ABCDEFGH";
22+
23+
private const string Audience = "rabbitmq";
24+
25+
[Fact]
26+
public async Task ConnectToRabbitMqWithOAuth2Token()
27+
{
28+
IConnection connection = await AmqpConnection.CreateAsync(
29+
ConnectionSettingsBuilder.Create()
30+
.Host("localhost")
31+
.Port(5672)
32+
.OAuth2Options(new OAuth2Options(GenerateToken(DateTime.UtcNow.AddMinutes(5))))
33+
.Build());
34+
35+
Assert.NotNull(connection);
36+
await connection.CloseAsync();
37+
}
38+
39+
private static string GenerateToken(DateTime expiration)
40+
{
41+
byte[] decodedKey = Convert.FromBase64String(Base64Key);
42+
43+
var claims = new List<Claim>
44+
{
45+
new (JwtRegisteredClaimNames.Iss, "unit_test"),
46+
new (JwtRegisteredClaimNames.Aud, Audience),
47+
new (JwtRegisteredClaimNames.Exp,
48+
new DateTimeOffset(expiration).ToUniversalTime().ToUnixTimeSeconds().ToString()),
49+
new ("scope", "rabbitmq.configure:*/*"),
50+
new ("scope", "rabbitmq.write:*/*"),
51+
new ("scope", "rabbitmq.read:*/*"),
52+
new ("random", RandomString(6))
53+
};
54+
55+
var tokenHandler = new JwtSecurityTokenHandler();
56+
var claimIdentity = new ClaimsIdentity(claims);
57+
var tokenDescriptor = new SecurityTokenDescriptor
58+
{
59+
Subject = claimIdentity,
60+
Expires = expiration,
61+
SigningCredentials =
62+
new SigningCredentials(new SymmetricSecurityKey(decodedKey),
63+
SecurityAlgorithms.HmacSha256Signature),
64+
Claims = new Dictionary<string, object> { ["kid"] = "token-key" }
65+
};
66+
67+
var token = tokenHandler.CreateToken(tokenDescriptor);
68+
return tokenHandler.WriteToken(token);
69+
}
70+
71+
private static string RandomString(int length)
72+
{
73+
const string chars = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";
74+
var result = new StringBuilder(length);
75+
76+
using (var rng = RandomNumberGenerator.Create())
77+
{
78+
byte[] byteArray = new byte[length];
79+
rng.GetBytes(byteArray);
80+
81+
foreach (byte byteValue in byteArray)
82+
{
83+
result.Append(chars[byteValue % chars.Length]);
84+
}
85+
}
86+
87+
return result.ToString();
88+
}
89+
}
90+
}

Tests/Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
<PackageReference Include="Microsoft.Extensions.Diagnostics.Testing" />
3030
<PackageReference Include="Microsoft.Extensions.Diagnostics" />
3131
<PackageReference Include="Microsoft.NET.Test.Sdk" />
32+
<PackageReference Include="System.IdentityModel.Tokens.Jwt" />
3233
<PackageReference Include="System.Text.Json" />
3334
<PackageReference Include="xunit" />
3435
<PackageReference Include="xunit.runner.visualstudio" />

0 commit comments

Comments
 (0)