Skip to content

Commit ea0a5a7

Browse files
authored
Merge pull request #966 from jkarder/fix-async-cancellation
Fix cancellation in ReceiveFrameBytesAsync
2 parents 10a263b + d297a27 commit ea0a5a7

File tree

1 file changed

+64
-19
lines changed

1 file changed

+64
-19
lines changed

src/NetMQ/AsyncReceiveExtensions.cs

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ public static class AsyncReceiveExtensions
2020
static Task<bool> s_trueTask = Task.FromResult(true);
2121
static Task<bool> s_falseTask = Task.FromResult(false);
2222

23-
#region Receiving frames as a multipart message
23+
#region Receiving frames as a multipart message
2424

2525
/// <summary>
2626
/// Receive a single frame from <paramref name="socket"/>, asynchronously.
@@ -52,9 +52,9 @@ public static async Task<NetMQMessage> ReceiveMultipartMessageAsync(
5252
return message;
5353
}
5454

55-
#endregion
55+
#endregion
5656

57-
#region Receiving a frame as a byte array
57+
#region Receiving a frame as a byte array
5858

5959
/// <summary>
6060
/// Receive a single frame from <paramref name="socket"/>, asynchronously.
@@ -85,7 +85,12 @@ public static async Task<NetMQMessage> ReceiveMultipartMessageAsync(
8585
}
8686

8787
TaskCompletionSource<(byte[], bool)> source = new TaskCompletionSource<(byte[], bool)>();
88-
cancellationToken.Register(() => source.SetCanceled());
88+
89+
CancellationTokenRegistration? registration = null;
90+
if (cancellationToken.CanBeCanceled)
91+
{
92+
registration = cancellationToken.Register(PropagateCancel);
93+
}
8994

9095
void Listener(object sender, NetMQSocketEventArgs args)
9196
{
@@ -96,18 +101,26 @@ void Listener(object sender, NetMQSocketEventArgs args)
96101
msg.Close();
97102

98103
socket.ReceiveReady -= Listener;
99-
source.SetResult((data, more));
104+
registration?.Dispose();
105+
source.TrySetResult((data, more));
100106
}
101107
}
108+
109+
void PropagateCancel()
110+
{
111+
socket.ReceiveReady -= Listener;
112+
registration?.Dispose();
113+
source.TrySetCanceled();
114+
}
102115

103116
socket.ReceiveReady += Listener;
104117

105118
return source.Task;
106119
}
107120

108-
#endregion
121+
#endregion
109122

110-
#region Receiving a frame as a string
123+
#region Receiving a frame as a string
111124

112125
/// <summary>
113126
/// Receive a single frame from <paramref name="socket"/>, asynchronously, and decode as a string using <see cref="SendReceiveConstants.DefaultEncoding"/>.
@@ -155,7 +168,12 @@ void Listener(object sender, NetMQSocketEventArgs args)
155168
}
156169

157170
TaskCompletionSource<(string, bool)> source = new TaskCompletionSource<(string,bool)>();
158-
cancellationToken.Register(() => source.SetCanceled());
171+
172+
CancellationTokenRegistration? registration = null;
173+
if (cancellationToken.CanBeCanceled)
174+
{
175+
registration = cancellationToken.Register(PropagateCancel);
176+
}
159177

160178
void Listener(object sender, NetMQSocketEventArgs args)
161179
{
@@ -165,28 +183,40 @@ void Listener(object sender, NetMQSocketEventArgs args)
165183
? msg.GetString(encoding)
166184
: string.Empty;
167185
bool more = msg.HasMore;
168-
169186
msg.Close();
187+
170188
socket.ReceiveReady -= Listener;
171-
source.SetResult((str, more));
189+
registration?.Dispose();
190+
source.TrySetResult((str, more));
172191
}
173192
}
174193

194+
void PropagateCancel()
195+
{
196+
socket.ReceiveReady -= Listener;
197+
registration?.Dispose();
198+
source.TrySetCanceled();
199+
}
200+
175201
socket.ReceiveReady += Listener;
176202

177203
return source.Task;
178204
}
179205

180-
#endregion
206+
#endregion
181207

182-
#region Skipping a message
208+
#region Skipping a message
183209

184210
/// <summary>
185211
/// Receive a single frame from <paramref name="socket"/>, asynchronously, then ignore its content.
186212
/// </summary>
187213
/// <param name="socket">The socket to receive from.</param>
214+
/// <param name="cancellationToken">The token used to propagate notification that this operation should be canceled.</param>
188215
/// <returns>Boolean indicate if another frame of the same message follows</returns>
189-
public static Task<bool> SkipFrameAsync(this NetMQSocket socket)
216+
public static Task<bool> SkipFrameAsync(
217+
this NetMQSocket socket,
218+
CancellationToken cancellationToken = default(CancellationToken)
219+
)
190220
{
191221
if (NetMQRuntime.Current == null)
192222
throw new InvalidOperationException("NetMQRuntime must be created before calling async functions");
@@ -206,26 +236,41 @@ public static Task<bool> SkipFrameAsync(this NetMQSocket socket)
206236

207237
TaskCompletionSource<bool> source = new TaskCompletionSource<bool>();
208238

239+
CancellationTokenRegistration? registration = null;
240+
if (cancellationToken.CanBeCanceled)
241+
{
242+
registration = cancellationToken.Register(PropagateCancel);
243+
}
244+
209245
void Listener(object sender, NetMQSocketEventArgs args)
210246
{
211247
if (socket.TryReceive(ref msg, TimeSpan.Zero))
212248
{
213249
bool more = msg.HasMore;
214250
msg.Close();
251+
215252
socket.ReceiveReady -= Listener;
216-
source.SetResult(more);
253+
registration?.Dispose();
254+
source.TrySetResult(more);
217255
}
218256
}
219257

258+
void PropagateCancel()
259+
{
260+
socket.ReceiveReady -= Listener;
261+
registration?.Dispose();
262+
source.TrySetCanceled();
263+
}
264+
220265
socket.ReceiveReady += Listener;
221266

222267
return source.Task;
223268
}
224269

225270

226-
#endregion
271+
#endregion
227272

228-
#region Skipping all frames of a multipart message
273+
#region Skipping all frames of a multipart message
229274

230275
/// <summary>
231276
/// Receive all frames of the next message from <paramref name="socket"/>, asynchronously, then ignore their contents.
@@ -242,9 +287,9 @@ public static async Task SkipMultipartMessageAsync(this NetMQSocket socket)
242287
}
243288

244289

245-
#endregion
290+
#endregion
246291

247-
#region Receiving a routing key
292+
#region Receiving a routing key
248293

249294
/// <summary>
250295
/// Receive a routing-key from <paramref name="socket"/>, blocking until one arrives.
@@ -259,7 +304,7 @@ public static async Task SkipMultipartMessageAsync(this NetMQSocket socket)
259304
return (new RoutingKey(bytes), more);
260305
}
261306

262-
#endregion
307+
#endregion
263308
}
264309
}
265310

0 commit comments

Comments
 (0)