Skip to content

Commit a3e2a69

Browse files
committed
* Set cancellation correctly for TaskCompletionSource in AsyncRpcContinuation
1 parent f331ddf commit a3e2a69

File tree

4 files changed

+171
-67
lines changed

4 files changed

+171
-67
lines changed

projects/RabbitMQ.Client/ConsumerDispatching/ConsumerDispatcherChannelBase.cs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,23 @@ internal ConsumerDispatcherChannelBase(Impl.Channel channel, ushort concurrency)
8585

8686
public ValueTask HandleBasicConsumeOkAsync(IAsyncBasicConsumer consumer, string consumerTag, CancellationToken cancellationToken)
8787
{
88+
cancellationToken.ThrowIfCancellationRequested();
89+
8890
if (false == _disposedValue && false == _quiesce)
8991
{
90-
AddConsumer(consumer, consumerTag);
91-
WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag);
92-
return _writer.WriteAsync(work, cancellationToken);
92+
try
93+
{
94+
AddConsumer(consumer, consumerTag);
95+
WorkStruct work = WorkStruct.CreateConsumeOk(consumer, consumerTag);
96+
97+
cancellationToken.ThrowIfCancellationRequested();
98+
return _writer.WriteAsync(work, cancellationToken);
99+
}
100+
catch
101+
{
102+
_ = GetAndRemoveConsumer(consumerTag);
103+
throw;
104+
}
93105
}
94106
else
95107
{
@@ -101,10 +113,14 @@ public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag,
101113
string exchange, string routingKey, IReadOnlyBasicProperties basicProperties, RentedMemory body,
102114
CancellationToken cancellationToken)
103115
{
116+
cancellationToken.ThrowIfCancellationRequested();
117+
104118
if (false == _disposedValue && false == _quiesce)
105119
{
106120
IAsyncBasicConsumer consumer = GetConsumerOrDefault(consumerTag);
107121
var work = WorkStruct.CreateDeliver(consumer, consumerTag, deliveryTag, redelivered, exchange, routingKey, basicProperties, body);
122+
123+
cancellationToken.ThrowIfCancellationRequested();
108124
return _writer.WriteAsync(work, cancellationToken);
109125
}
110126
else
@@ -115,10 +131,14 @@ public ValueTask HandleBasicDeliverAsync(string consumerTag, ulong deliveryTag,
115131

116132
public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken cancellationToken)
117133
{
134+
cancellationToken.ThrowIfCancellationRequested();
135+
118136
if (false == _disposedValue && false == _quiesce)
119137
{
120138
IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag);
121139
WorkStruct work = WorkStruct.CreateCancelOk(consumer, consumerTag);
140+
141+
cancellationToken.ThrowIfCancellationRequested();
122142
return _writer.WriteAsync(work, cancellationToken);
123143
}
124144
else
@@ -129,10 +149,14 @@ public ValueTask HandleBasicCancelOkAsync(string consumerTag, CancellationToken
129149

130150
public ValueTask HandleBasicCancelAsync(string consumerTag, CancellationToken cancellationToken)
131151
{
152+
cancellationToken.ThrowIfCancellationRequested();
153+
132154
if (false == _disposedValue && false == _quiesce)
133155
{
134156
IAsyncBasicConsumer consumer = GetAndRemoveConsumer(consumerTag);
135157
WorkStruct work = WorkStruct.CreateCancel(consumer, consumerTag);
158+
159+
cancellationToken.ThrowIfCancellationRequested();
136160
return _writer.WriteAsync(work, cancellationToken);
137161
}
138162
else

projects/RabbitMQ.Client/Impl/AsyncRpcContinuations.cs

Lines changed: 44 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ internal abstract class AsyncRpcContinuation<T> : IRpcContinuation
5151

5252
private bool _disposedValue;
5353

54-
public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken cancellationToken)
54+
public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken rpcCancellationToken)
5555
{
5656
/*
5757
* Note: we can't use an ObjectPool for these because the netstandard2.0
@@ -89,7 +89,7 @@ public AsyncRpcContinuation(TimeSpan continuationTimeout, CancellationToken canc
8989
_tcsConfiguredTaskAwaitable = _tcs.Task.ConfigureAwait(false);
9090

9191
_linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(
92-
_continuationTimeoutCancellationTokenSource.Token, cancellationToken);
92+
_continuationTimeoutCancellationTokenSource.Token, rpcCancellationToken);
9393
}
9494

9595
public CancellationToken CancellationToken
@@ -105,7 +105,27 @@ public ConfiguredTaskAwaitable<T>.ConfiguredTaskAwaiter GetAwaiter()
105105
return _tcsConfiguredTaskAwaitable.GetAwaiter();
106106
}
107107

108-
public abstract Task HandleCommandAsync(IncomingCommand cmd);
108+
public async Task HandleCommandAsync(IncomingCommand cmd)
109+
{
110+
try
111+
{
112+
await DoHandleCommandAsync(cmd)
113+
.ConfigureAwait(false);
114+
}
115+
catch (OperationCanceledException)
116+
{
117+
if (CancellationToken.IsCancellationRequested)
118+
{
119+
_tcs.SetCanceled();
120+
}
121+
else
122+
{
123+
throw;
124+
}
125+
}
126+
}
127+
128+
protected abstract Task DoHandleCommandAsync(IncomingCommand cmd);
109129

110130
public virtual void HandleChannelShutdown(ShutdownEventArgs reason)
111131
{
@@ -141,17 +161,17 @@ public ConnectionSecureOrTuneAsyncRpcContinuation(TimeSpan continuationTimeout,
141161
{
142162
}
143163

144-
public override Task HandleCommandAsync(IncomingCommand cmd)
164+
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
145165
{
146166
if (cmd.CommandId == ProtocolCommandId.ConnectionSecure)
147167
{
148168
var secure = new ConnectionSecure(cmd.MethodSpan);
149-
_tcs.TrySetResult(new ConnectionSecureOrTune(secure._challenge, default));
169+
_tcs.SetResult(new ConnectionSecureOrTune(secure._challenge, default));
150170
}
151171
else if (cmd.CommandId == ProtocolCommandId.ConnectionTune)
152172
{
153173
var tune = new ConnectionTune(cmd.MethodSpan);
154-
_tcs.TrySetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails
174+
_tcs.SetResult(new ConnectionSecureOrTune(default, new ConnectionTuneDetails
155175
{
156176
m_channelMax = tune._channelMax,
157177
m_frameMax = tune._frameMax,
@@ -178,11 +198,11 @@ public SimpleAsyncRpcContinuation(ProtocolCommandId expectedCommandId, TimeSpan
178198
_expectedCommandId = expectedCommandId;
179199
}
180200

181-
public override Task HandleCommandAsync(IncomingCommand cmd)
201+
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
182202
{
183203
if (cmd.CommandId == _expectedCommandId)
184204
{
185-
_tcs.TrySetResult(true);
205+
_tcs.SetResult(true);
186206
}
187207
else
188208
{
@@ -206,14 +226,14 @@ public BasicCancelAsyncRpcContinuation(string consumerTag, IConsumerDispatcher c
206226
_consumerDispatcher = consumerDispatcher;
207227
}
208228

209-
public override async Task HandleCommandAsync(IncomingCommand cmd)
229+
protected override async Task DoHandleCommandAsync(IncomingCommand cmd)
210230
{
211231
if (cmd.CommandId == ProtocolCommandId.BasicCancelOk)
212232
{
213-
_tcs.TrySetResult(true);
214233
Debug.Assert(_consumerTag == new BasicCancelOk(cmd.MethodSpan)._consumerTag);
215234
await _consumerDispatcher.HandleBasicCancelOkAsync(_consumerTag, CancellationToken)
216235
.ConfigureAwait(false);
236+
_tcs.SetResult(true);
217237
}
218238
else
219239
{
@@ -235,14 +255,16 @@ public BasicConsumeAsyncRpcContinuation(IAsyncBasicConsumer consumer, IConsumerD
235255
_consumerDispatcher = consumerDispatcher;
236256
}
237257

238-
public override async Task HandleCommandAsync(IncomingCommand cmd)
258+
protected override async Task DoHandleCommandAsync(IncomingCommand cmd)
239259
{
240260
if (cmd.CommandId == ProtocolCommandId.BasicConsumeOk)
241261
{
242262
var method = new BasicConsumeOk(cmd.MethodSpan);
243-
_tcs.TrySetResult(method._consumerTag);
263+
244264
await _consumerDispatcher.HandleBasicConsumeOkAsync(_consumer, method._consumerTag, CancellationToken)
245265
.ConfigureAwait(false);
266+
267+
_tcs.SetResult(method._consumerTag);
246268
}
247269
else
248270
{
@@ -264,7 +286,7 @@ public BasicGetAsyncRpcContinuation(Func<ulong, ulong> adjustDeliveryTag,
264286

265287
internal DateTime StartTime { get; } = DateTime.UtcNow;
266288

267-
public override Task HandleCommandAsync(IncomingCommand cmd)
289+
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
268290
{
269291
if (cmd.CommandId == ProtocolCommandId.BasicGetOk)
270292
{
@@ -280,11 +302,11 @@ public override Task HandleCommandAsync(IncomingCommand cmd)
280302
header,
281303
cmd.Body.ToArray());
282304

283-
_tcs.TrySetResult(result);
305+
_tcs.SetResult(result);
284306
}
285307
else if (cmd.CommandId == ProtocolCommandId.BasicGetEmpty)
286308
{
287-
_tcs.TrySetResult(null);
309+
_tcs.SetResult(null);
288310
}
289311
else
290312
{
@@ -325,7 +347,7 @@ public override void HandleChannelShutdown(ShutdownEventArgs reason)
325347

326348
public Task OnConnectionShutdownAsync(object? sender, ShutdownEventArgs reason)
327349
{
328-
_tcs.TrySetResult(true);
350+
_tcs.SetResult(true);
329351
return Task.CompletedTask;
330352
}
331353
}
@@ -377,13 +399,13 @@ public QueueDeclareAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellati
377399
{
378400
}
379401

380-
public override Task HandleCommandAsync(IncomingCommand cmd)
402+
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
381403
{
382404
if (cmd.CommandId == ProtocolCommandId.QueueDeclareOk)
383405
{
384406
var method = new Client.Framing.QueueDeclareOk(cmd.MethodSpan);
385407
var result = new QueueDeclareOk(method._queue, method._messageCount, method._consumerCount);
386-
_tcs.TrySetResult(result);
408+
_tcs.SetResult(result);
387409
}
388410
else
389411
{
@@ -417,12 +439,12 @@ public QueueDeleteAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellatio
417439
{
418440
}
419441

420-
public override Task HandleCommandAsync(IncomingCommand cmd)
442+
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
421443
{
422444
if (cmd.CommandId == ProtocolCommandId.QueueDeleteOk)
423445
{
424446
var method = new QueueDeleteOk(cmd.MethodSpan);
425-
_tcs.TrySetResult(method._messageCount);
447+
_tcs.SetResult(method._messageCount);
426448
}
427449
else
428450
{
@@ -440,12 +462,12 @@ public QueuePurgeAsyncRpcContinuation(TimeSpan continuationTimeout, Cancellation
440462
{
441463
}
442464

443-
public override Task HandleCommandAsync(IncomingCommand cmd)
465+
protected override Task DoHandleCommandAsync(IncomingCommand cmd)
444466
{
445467
if (cmd.CommandId == ProtocolCommandId.QueuePurgeOk)
446468
{
447469
var method = new QueuePurgeOk(cmd.MethodSpan);
448-
_tcs.TrySetResult(method._messageCount);
470+
_tcs.SetResult(method._messageCount);
449471
}
450472
else
451473
{
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// This source code is dual-licensed under the Apache License, version
2+
// 2.0, and the Mozilla Public License, version 2.0.
3+
//
4+
// The APL v2.0:
5+
//
6+
//---------------------------------------------------------------------------
7+
// Copyright (c) 2007-2024 Broadcom. All Rights Reserved.
8+
//
9+
// Licensed under the Apache License, Version 2.0 (the "License");
10+
// you may not use this file except in compliance with the License.
11+
// You may obtain a copy of the License at
12+
//
13+
// https://www.apache.org/licenses/LICENSE-2.0
14+
//
15+
// Unless required by applicable law or agreed to in writing, software
16+
// distributed under the License is distributed on an "AS IS" BASIS,
17+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
// See the License for the specific language governing permissions and
19+
// limitations under the License.
20+
//---------------------------------------------------------------------------
21+
//
22+
// The MPL v2.0:
23+
//
24+
//---------------------------------------------------------------------------
25+
// This Source Code Form is subject to the terms of the Mozilla Public
26+
// License, v. 2.0. If a copy of the MPL was not distributed with this
27+
// file, You can obtain one at https://mozilla.org/MPL/2.0/.
28+
//
29+
// Copyright (c) 2007-2024 Broadcom. All Rights Reserved.
30+
//---------------------------------------------------------------------------
31+
32+
using System.Threading;
33+
using System.Threading.Tasks;
34+
using RabbitMQ.Client;
35+
using RabbitMQ.Client.Events;
36+
using Xunit;
37+
using Xunit.Abstractions;
38+
39+
#nullable enable
40+
41+
namespace Test.Integration.GH
42+
{
43+
public class TestGitHubIssues : IntegrationFixture
44+
{
45+
public TestGitHubIssues(ITestOutputHelper output) : base(output)
46+
{
47+
}
48+
49+
public override Task InitializeAsync()
50+
{
51+
// NB: nothing to do here since each test creates its own factory,
52+
// connections and channels
53+
Assert.Null(_connFactory);
54+
Assert.Null(_conn);
55+
Assert.Null(_channel);
56+
return Task.CompletedTask;
57+
}
58+
59+
[Fact]
60+
public async Task TestBasicConsumeCancellation_GH1750()
61+
{
62+
/*
63+
* Note:
64+
* Testing that the task is actually canceled requires a hacked RabbitMQ server.
65+
* Modify deps/rabbit/src/rabbit_channel.erl, handle_cast for basic.consume_ok
66+
* Before send/2, add timer:sleep(1000), then `make run-broker`
67+
*
68+
* The _output line at the end of the test will print TaskCanceledException
69+
*/
70+
Assert.Null(_connFactory);
71+
Assert.Null(_conn);
72+
Assert.Null(_channel);
73+
74+
_connFactory = CreateConnectionFactory();
75+
_connFactory.AutomaticRecoveryEnabled = false;
76+
_connFactory.TopologyRecoveryEnabled = false;
77+
78+
_conn = await _connFactory.CreateConnectionAsync();
79+
_channel = await _conn.CreateChannelAsync();
80+
81+
QueueDeclareOk q = await _channel.QueueDeclareAsync();
82+
83+
var consumer = new AsyncEventingBasicConsumer(_channel);
84+
consumer.ReceivedAsync += (o, a) =>
85+
{
86+
return Task.CompletedTask;
87+
};
88+
89+
try
90+
{
91+
using var cts = new CancellationTokenSource(5);
92+
await _channel.BasicConsumeAsync(q.QueueName, true, consumer, cts.Token);
93+
}
94+
catch (TaskCanceledException ex)
95+
{
96+
_output.WriteLine("ex: {0}", ex);
97+
}
98+
}
99+
}
100+
}

0 commit comments

Comments
 (0)