Skip to content

Commit f5bf954

Browse files
saknolukebakken
authored andcommitted
Added ability to specify custom ArrayPool
Review feedback Added MemoryPool property to approved API Prefer a ctor argument to specify your own ArrayPool<byte>
1 parent 0c2b5f8 commit f5bf954

17 files changed

+112
-45
lines changed

projects/RabbitMQ.Client/client/api/ConnectionFactory.cs

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131

3232
using System;
3333
using System.Collections.Generic;
34+
using System.Buffers;
3435
using System.Linq;
3536
using System.Net.Security;
3637
using System.Security.Authentication;
@@ -188,6 +189,15 @@ public sealed class ConnectionFactory : ConnectionFactoryBase, IAsyncConnectionF
188189

189190
// just here to hold the value that was set through the setter
190191
private Uri _uri;
192+
private readonly ArrayPool<byte> _memoryPool;
193+
194+
/// <summary>
195+
/// The memory pool used for allocating buffers. Default is <see cref="MemoryPool{T}.Shared"/>.
196+
/// </summary>
197+
public ArrayPool<byte> MemoryPool
198+
{
199+
get => _memoryPool;
200+
}
191201

192202
/// <summary>
193203
/// Amount of time protocol handshake operations are allowed to take before
@@ -258,6 +268,18 @@ public TimeSpan ContinuationTimeout
258268
public ConnectionFactory()
259269
{
260270
ClientProperties = Connection.DefaultClientProperties();
271+
_memoryPool = ArrayPool<byte>.Shared;
272+
}
273+
274+
/// <summary>
275+
/// Construct a fresh instance, with all fields set to their respective defaults,
276+
/// using your own memory pool.
277+
/// <param name="memoryPool">Memory pool to use with all Connections</param>
278+
/// </summary>
279+
public ConnectionFactory(ArrayPool<byte> memoryPool)
280+
{
281+
ClientProperties = Connection.DefaultClientProperties();
282+
_memoryPool = memoryPool;
261283
}
262284

263285
/// <summary>
@@ -497,7 +519,8 @@ public IConnection CreateConnection(IEndpointResolver endpointResolver, string c
497519
else
498520
{
499521
var protocol = new RabbitMQ.Client.Framing.Protocol();
500-
conn = protocol.CreateConnection(this, false, endpointResolver.SelectOne(CreateFrameHandler), clientProvidedName);
522+
conn = protocol.CreateConnection(this, false, endpointResolver.SelectOne(CreateFrameHandler),
523+
_memoryPool, clientProvidedName);
501524
}
502525
}
503526
catch (Exception e)
@@ -510,7 +533,7 @@ public IConnection CreateConnection(IEndpointResolver endpointResolver, string c
510533

511534
internal IFrameHandler CreateFrameHandler(AmqpTcpEndpoint endpoint)
512535
{
513-
IFrameHandler fh = Protocols.DefaultProtocol.CreateFrameHandler(endpoint, SocketFactory,
536+
IFrameHandler fh = Protocols.DefaultProtocol.CreateFrameHandler(endpoint, _memoryPool, SocketFactory,
514537
RequestedConnectionTimeout, SocketReadTimeout, SocketWriteTimeout);
515538
return ConfigureFrameHandler(fh);
516539
}

projects/RabbitMQ.Client/client/impl/AsyncConsumerDispatcher.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,10 +47,11 @@ public void HandleBasicDeliver(IBasicConsumer consumer,
4747
IBasicProperties basicProperties,
4848
ReadOnlySpan<byte> body)
4949
{
50-
byte[] bodyBytes = ArrayPool<byte>.Shared.Rent(body.Length);
50+
var pool = _model.Session.Connection.MemoryPool;
51+
byte[] bodyBytes = pool.Rent(body.Length);
5152
Memory<byte> bodyCopy = new Memory<byte>(bodyBytes, 0, body.Length);
5253
body.CopyTo(bodyCopy.Span);
53-
ScheduleUnlessShuttingDown(new BasicDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, bodyCopy));
54+
ScheduleUnlessShuttingDown(new BasicDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, bodyCopy, pool));
5455
}
5556

5657
public void HandleBasicCancelOk(IBasicConsumer consumer, string consumerTag)

projects/RabbitMQ.Client/client/impl/AutorecoveringConnection.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,7 @@ private void Init(IFrameHandler fh)
683683
throw new ObjectDisposedException(GetType().FullName);
684684
}
685685

686-
_delegate = new Connection(_factory, false,
687-
fh, ClientProvidedName);
686+
_delegate = new Connection(_factory, false, fh, _factory.MemoryPool, ClientProvidedName);
688687

689688
_recoveryTask = Task.Run(MainRecoveryLoop);
690689

@@ -1017,7 +1016,7 @@ private bool TryRecoverConnectionDelegate()
10171016
try
10181017
{
10191018
IFrameHandler fh = _endpoints.SelectOne(_factory.CreateFrameHandler);
1020-
_delegate = new Connection(_factory, false, fh, ClientProvidedName);
1019+
_delegate = new Connection(_factory, false, fh, _factory.MemoryPool, ClientProvidedName);
10211020
return true;
10221021
}
10231022
catch (Exception e)

projects/RabbitMQ.Client/client/impl/BasicDeliver.cs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ internal sealed class BasicDeliver : Work
1414
private readonly string _routingKey;
1515
private readonly IBasicProperties _basicProperties;
1616
private readonly ReadOnlyMemory<byte> _body;
17+
private readonly ArrayPool<byte> _bodyOwner;
1718

1819
public override string Context => "HandleBasicDeliver";
1920

@@ -24,7 +25,8 @@ public BasicDeliver(IBasicConsumer consumer,
2425
string exchange,
2526
string routingKey,
2627
IBasicProperties basicProperties,
27-
ReadOnlyMemory<byte> body) : base(consumer)
28+
ReadOnlyMemory<byte> body,
29+
ArrayPool<byte> pool) : base(consumer)
2830
{
2931
_consumerTag = consumerTag;
3032
_deliveryTag = deliveryTag;
@@ -33,6 +35,7 @@ public BasicDeliver(IBasicConsumer consumer,
3335
_routingKey = routingKey;
3436
_basicProperties = basicProperties;
3537
_body = body;
38+
_bodyOwner = pool;
3639
}
3740

3841
protected override Task Execute(IAsyncBasicConsumer consumer)
@@ -50,7 +53,7 @@ public override void PostExecute()
5053
{
5154
if (MemoryMarshal.TryGetArray(_body, out ArraySegment<byte> segment))
5255
{
53-
ArrayPool<byte>.Shared.Return(segment.Array);
56+
_bodyOwner.Return(segment.Array);
5457
}
5558
}
5659
}

projects/RabbitMQ.Client/client/impl/CommandAssembler.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ public IncomingCommand HandleFrame(in InboundFrame frame)
8888
return IncomingCommand.Empty;
8989
}
9090

91-
var result = new IncomingCommand(_method, _header, _body, _bodyBytes);
91+
var result = new IncomingCommand(_method, _header, _body, _bodyBytes, _protocol.MemoryPool);
9292
Reset();
9393
return result;
9494
}
@@ -123,7 +123,7 @@ private void ParseHeaderFrame(in InboundFrame frame)
123123
_remainingBodyBytes = (int) totalBodyBytes;
124124

125125
// Is returned by IncomingCommand.Dispose in Session.HandleFrame
126-
_bodyBytes = ArrayPool<byte>.Shared.Rent(_remainingBodyBytes);
126+
_bodyBytes = _protocol.MemoryPool.Rent(_remainingBodyBytes);
127127
_body = new Memory<byte>(_bodyBytes, 0, _remainingBodyBytes);
128128
UpdateContentBodyState();
129129
}

projects/RabbitMQ.Client/client/impl/ConcurrentConsumerDispatcher.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ public void HandleBasicDeliver(IBasicConsumer consumer,
6464
IBasicProperties basicProperties,
6565
ReadOnlySpan<byte> body)
6666
{
67-
byte[] memoryCopyArray = ArrayPool<byte>.Shared.Rent(body.Length);
67+
var pool = _model.Session.Connection.MemoryPool;
68+
byte[] memoryCopyArray = pool.Rent(body.Length);
6869
Memory<byte> memoryCopy = new Memory<byte>(memoryCopyArray, 0, body.Length);
6970
body.CopyTo(memoryCopy.Span);
7071
UnlessShuttingDown(() =>
@@ -90,7 +91,7 @@ public void HandleBasicDeliver(IBasicConsumer consumer,
9091
}
9192
finally
9293
{
93-
ArrayPool<byte>.Shared.Return(memoryCopyArray);
94+
pool.Return(memoryCopyArray);
9495
}
9596
});
9697
}

projects/RabbitMQ.Client/client/impl/Connection.cs

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ internal sealed class Connection : IConnection
7070
private volatile bool _running = true;
7171
private readonly MainSession _session0;
7272
private SessionManager _sessionManager;
73+
private readonly ArrayPool<byte> _memoryPool = ArrayPool<byte>.Shared;
7374

7475
//
7576
// Heartbeats
@@ -127,6 +128,18 @@ public Connection(IConnectionFactory factory, bool insist, IFrameHandler frameHa
127128
}
128129
}
129130

131+
public Connection(IConnectionFactory factory, bool insist, IFrameHandler frameHandler, ArrayPool<byte> memoryPool,
132+
string clientProvidedName = null)
133+
: this(factory, insist, frameHandler, clientProvidedName)
134+
{
135+
_memoryPool = memoryPool;
136+
}
137+
138+
internal ArrayPool<byte> MemoryPool
139+
{
140+
get => _memoryPool;
141+
}
142+
130143
public Guid Id { get { return _id; } }
131144

132145
public event EventHandler<CallbackExceptionEventArgs> CallbackException;
@@ -908,7 +921,7 @@ public void HeartbeatWriteTimerCallback(object state)
908921
{
909922
if (!_closed)
910923
{
911-
Write(Client.Impl.Framing.Heartbeat.GetHeartbeatFrame());
924+
Write(Client.Impl.Framing.Heartbeat.GetHeartbeatFrame(MemoryPool));
912925
_heartbeatWriteTimer?.Change((int)_heartbeatTimeSpan.TotalMilliseconds, Timeout.Infinite);
913926
}
914927
}
@@ -945,7 +958,7 @@ public override string ToString()
945958
return string.Format("Connection({0},{1})", _id, Endpoint);
946959
}
947960

948-
public void Write(Memory<byte> memory)
961+
public void Write(ReadOnlyMemory<byte> memory)
949962
{
950963
_frameHandler.Write(memory);
951964
}

projects/RabbitMQ.Client/client/impl/Frame.cs

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -147,12 +147,12 @@ internal static class Heartbeat
147147
Constants.FrameEnd
148148
};
149149

150-
public static Memory<byte> GetHeartbeatFrame()
150+
public static ReadOnlyMemory<byte> GetHeartbeatFrame(ArrayPool<byte> pool)
151151
{
152152
// Is returned by SocketFrameHandler.WriteLoop
153-
var buffer = ArrayPool<byte>.Shared.Rent(FrameSize);
153+
var buffer = pool.Rent(FrameSize);
154154
Payload.CopyTo(buffer);
155-
return new Memory<byte>(buffer, 0, FrameSize);
155+
return new ReadOnlyMemory<byte>(buffer, 0, FrameSize);
156156
}
157157
}
158158
}
@@ -163,13 +163,15 @@ public static Memory<byte> GetHeartbeatFrame()
163163
public readonly int Channel;
164164
public readonly ReadOnlyMemory<byte> Payload;
165165
private readonly byte[] _rentedArray;
166+
private readonly ArrayPool<byte> _rentedArrayOwner;
166167

167-
private InboundFrame(FrameType type, int channel, ReadOnlyMemory<byte> payload, byte[] rentedArray)
168+
private InboundFrame(FrameType type, int channel, ReadOnlyMemory<byte> payload, byte[] rentedArray, ArrayPool<byte> rentedArrayOwner)
168169
{
169170
Type = type;
170171
Channel = channel;
171172
Payload = payload;
172173
_rentedArray = rentedArray;
174+
_rentedArrayOwner = rentedArrayOwner;
173175
}
174176

175177
private static void ProcessProtocolHeader(Stream reader)
@@ -203,7 +205,7 @@ private static void ProcessProtocolHeader(Stream reader)
203205
}
204206
}
205207

206-
internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer)
208+
internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer, ArrayPool<byte> pool)
207209
{
208210
int type = default;
209211
try
@@ -242,7 +244,7 @@ internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer)
242244
const int EndMarkerLength = 1;
243245
// Is returned by InboundFrame.Dispose in Connection.MainLoopIteration
244246
var readSize = payloadSize + EndMarkerLength;
245-
byte[] payloadBytes = ArrayPool<byte>.Shared.Rent(readSize);
247+
byte[] payloadBytes = pool.Rent(readSize);
246248
int bytesRead = 0;
247249
try
248250
{
@@ -254,22 +256,22 @@ internal static InboundFrame ReadFrom(Stream reader, byte[] frameHeaderBuffer)
254256
catch (Exception)
255257
{
256258
// Early EOF.
257-
ArrayPool<byte>.Shared.Return(payloadBytes);
259+
pool.Return(payloadBytes);
258260
throw new MalformedFrameException($"Short frame - expected to read {readSize} bytes, only got {bytesRead} bytes");
259261
}
260262

261263
if (payloadBytes[payloadSize] != Constants.FrameEnd)
262264
{
263-
ArrayPool<byte>.Shared.Return(payloadBytes);
265+
pool.Return(payloadBytes);
264266
throw new MalformedFrameException($"Bad frame end marker: {payloadBytes[payloadSize]}");
265267
}
266268

267-
return new InboundFrame((FrameType)type, channel, new Memory<byte>(payloadBytes, 0, payloadSize), payloadBytes);
269+
return new InboundFrame((FrameType)type, channel, new Memory<byte>(payloadBytes, 0, payloadSize), payloadBytes, pool);
268270
}
269271

270272
public void Dispose()
271273
{
272-
ArrayPool<byte>.Shared.Return(_rentedArray);
274+
_rentedArrayOwner.Return(_rentedArray);
273275
}
274276

275277
public override string ToString()

projects/RabbitMQ.Client/client/impl/IFrameHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,6 @@ interface IFrameHandler
6161

6262
void SendHeader();
6363

64-
void Write(Memory<byte> memory);
64+
void Write(ReadOnlyMemory<byte> memory);
6565
}
6666
}

projects/RabbitMQ.Client/client/impl/IProtocolExtensions.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,8 @@
3030
//---------------------------------------------------------------------------
3131

3232
using System;
33-
33+
using System.Buffers;
3434
using System.Net.Sockets;
35-
3635
using RabbitMQ.Client.Impl;
3736

3837
namespace RabbitMQ.Client.Framing.Impl
@@ -42,12 +41,16 @@ static class IProtocolExtensions
4241
public static IFrameHandler CreateFrameHandler(
4342
this IProtocol protocol,
4443
AmqpTcpEndpoint endpoint,
44+
ArrayPool<byte> pool,
4545
Func<AddressFamily, ITcpClient> socketFactory,
4646
TimeSpan connectionTimeout,
4747
TimeSpan readTimeout,
4848
TimeSpan writeTimeout)
4949
{
50-
return new SocketFrameHandler(endpoint, socketFactory, connectionTimeout, readTimeout, writeTimeout);
50+
return new SocketFrameHandler(endpoint, socketFactory, connectionTimeout, readTimeout, writeTimeout)
51+
{
52+
MemoryPool = pool
53+
};
5154
}
5255
}
5356
}

0 commit comments

Comments
 (0)