Skip to content

Commit 983792a

Browse files
committed
Replace AwaitableSocketAsyncEventArgs in SocketExtensions
The existing AwaitableSocketAsyncEventArgs is useful in principal for being reusable in order to save on allocations. However, we don't reuse it and the implementation is flawed. Instead, use implementations based on TaskCompletionSource, and add a SendAsync method. Because sockets are only natively cancellable on modern .NET, I was torn between 3 options for cancellation on the targets which use SocketExtensions: 1. Do not respect the CancellationToken once the socket operation has started. I believe this is what earlier versions of .NET Core did when CancellationToken overloads were first added via SocketTaskExtensions. 2. Do not close the socket upon cancellation, meaning the socket operation continues to run after the Task has completed. This is what the previous implementation effectively does. 3. Close the socket when the CancellationToken is cancelled, in order to stop the socket operation. The behaviour of a socket after (proper) cancellation is undefined(?), so in any case it should not make sense to use the socket after triggering cancellation. I felt that option 2 was the worst of them. This iteration goes for option 3.
1 parent 0876e88 commit 983792a

File tree

2 files changed

+79
-87
lines changed

2 files changed

+79
-87
lines changed

src/Renci.SshNet/Abstractions/SocketAbstraction.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -312,9 +312,9 @@ public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeS
312312
}
313313

314314
#if NET6_0_OR_GREATER == false
315-
public static Task<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
315+
public static ValueTask<int> ReadAsync(Socket socket, byte[] buffer, CancellationToken cancellationToken)
316316
{
317-
return socket.ReceiveAsync(buffer, 0, buffer.Length, cancellationToken);
317+
return socket.ReceiveAsync(new ArraySegment<byte>(buffer, 0, buffer.Length), SocketFlags.None, cancellationToken);
318318
}
319319
#endif
320320

Lines changed: 77 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1,134 +1,126 @@
1-
#if !NET6_0_OR_GREATER
1+
#if !NET
2+
#if NETFRAMEWORK || NETSTANDARD2_0
23
using System;
4+
#endif
35
using System.Net;
46
using System.Net.Sockets;
5-
using System.Runtime.CompilerServices;
67
using System.Threading;
78
using System.Threading.Tasks;
89

910
namespace Renci.SshNet.Abstractions
1011
{
11-
// Async helpers based on https://devblogs.microsoft.com/pfxteam/awaiting-socket-operations/
1212
internal static class SocketExtensions
1313
{
14-
private sealed class AwaitableSocketAsyncEventArgs : SocketAsyncEventArgs, INotifyCompletion
14+
public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
1515
{
16-
private static readonly Action SENTINEL = () => { };
16+
cancellationToken.ThrowIfCancellationRequested();
1717

18-
private bool _isCancelled;
19-
private Action _continuationAction;
18+
TaskCompletionSource<object> tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
2019

21-
public AwaitableSocketAsyncEventArgs()
20+
using var args = new SocketAsyncEventArgs
2221
{
23-
Completed += (sender, e) => SetCompleted();
24-
}
22+
RemoteEndPoint = remoteEndpoint
23+
};
24+
args.Completed += (_, _) => tcs.TrySetResult(null);
2525

26-
public AwaitableSocketAsyncEventArgs ExecuteAsync(Func<SocketAsyncEventArgs, bool> func)
26+
if (socket.ConnectAsync(args))
2727
{
28-
if (!func(this))
28+
#if NETSTANDARD2_1
29+
await using (cancellationToken.Register(() =>
30+
#else
31+
using (cancellationToken.Register(() =>
32+
#endif
2933
{
30-
SetCompleted();
31-
}
32-
33-
return this;
34-
}
35-
36-
public void SetCompleted()
37-
{
38-
IsCompleted = true;
39-
40-
var continuation = _continuationAction ?? Interlocked.CompareExchange(ref _continuationAction, SENTINEL, comparand: null);
41-
if (continuation is not null)
34+
if (tcs.TrySetCanceled(cancellationToken))
35+
{
36+
socket.Dispose();
37+
}
38+
},
39+
useSynchronizationContext: false)
40+
#if NETSTANDARD2_1
41+
.ConfigureAwait(false)
42+
#endif
43+
)
4244
{
43-
continuation();
45+
_ = await tcs.Task.ConfigureAwait(false);
4446
}
4547
}
4648

47-
public void SetCancelled()
49+
if (args.SocketError != SocketError.Success)
4850
{
49-
_isCancelled = true;
50-
SetCompleted();
51+
throw new SocketException((int) args.SocketError);
5152
}
53+
}
5254

53-
#pragma warning disable S1144 // Unused private types or members should be removed
54-
public AwaitableSocketAsyncEventArgs GetAwaiter()
55-
#pragma warning restore S1144 // Unused private types or members should be removed
56-
{
57-
return this;
58-
}
55+
#if NETFRAMEWORK || NETSTANDARD2_0
56+
public static async ValueTask<int> ReceiveAsync(this Socket socket, ArraySegment<byte> buffer, SocketFlags socketFlags, CancellationToken cancellationToken)
57+
{
58+
cancellationToken.ThrowIfCancellationRequested();
5959

60-
public bool IsCompleted { get; private set; }
60+
TaskCompletionSource<object> tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
6161

62-
void INotifyCompletion.OnCompleted(Action continuation)
63-
{
64-
if (_continuationAction == SENTINEL || Interlocked.CompareExchange(ref _continuationAction, continuation, comparand: null) == SENTINEL)
65-
{
66-
// We have already completed; run continuation asynchronously
67-
_ = Task.Run(continuation);
68-
}
69-
}
62+
using var args = new SocketAsyncEventArgs();
63+
args.SocketFlags = socketFlags;
64+
args.Completed += (_, _) => tcs.TrySetResult(null);
65+
args.SetBuffer(buffer.Array, buffer.Offset, buffer.Count);
7066

71-
#pragma warning disable S1144 // Unused private types or members should be removed
72-
public void GetResult()
73-
#pragma warning restore S1144 // Unused private types or members should be removed
67+
if (socket.ReceiveAsync(args))
7468
{
75-
if (_isCancelled)
69+
using (cancellationToken.Register(() =>
7670
{
77-
throw new TaskCanceledException();
78-
}
79-
80-
if (!IsCompleted)
81-
{
82-
// We don't support sync/async
83-
throw new InvalidOperationException("The asynchronous operation has not yet completed.");
84-
}
85-
86-
if (SocketError != SocketError.Success)
71+
if (tcs.TrySetCanceled(cancellationToken))
72+
{
73+
socket.Dispose();
74+
}
75+
},
76+
useSynchronizationContext: false))
8777
{
88-
throw new SocketException((int)SocketError);
78+
_ = await tcs.Task.ConfigureAwait(false);
8979
}
9080
}
91-
}
9281

93-
public static async Task ConnectAsync(this Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
94-
{
95-
cancellationToken.ThrowIfCancellationRequested();
96-
97-
using (var args = new AwaitableSocketAsyncEventArgs())
82+
if (args.SocketError != SocketError.Success)
9883
{
99-
args.RemoteEndPoint = remoteEndpoint;
100-
101-
#if NET || NETSTANDARD2_1_OR_GREATER
102-
await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs)o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false))
103-
#else
104-
using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false))
105-
#endif // NET || NETSTANDARD2_1_OR_GREATER
106-
{
107-
await args.ExecuteAsync(socket.ConnectAsync);
108-
}
84+
throw new SocketException((int) args.SocketError);
10985
}
86+
87+
return args.BytesTransferred;
11088
}
11189

112-
public static async Task<int> ReceiveAsync(this Socket socket, byte[] buffer, int offset, int length, CancellationToken cancellationToken)
90+
public static async ValueTask<int> SendAsync(this Socket socket, byte[] buffer, SocketFlags socketFlags, CancellationToken cancellationToken)
11391
{
11492
cancellationToken.ThrowIfCancellationRequested();
11593

116-
using (var args = new AwaitableSocketAsyncEventArgs())
117-
{
118-
args.SetBuffer(buffer, offset, length);
94+
TaskCompletionSource<object> tcs = new(TaskCreationOptions.RunContinuationsAsynchronously);
11995

120-
#if NET || NETSTANDARD2_1_OR_GREATER
121-
await using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false).ConfigureAwait(continueOnCapturedContext: false))
122-
#else
123-
using (cancellationToken.Register(o => ((AwaitableSocketAsyncEventArgs) o).SetCancelled(), args, useSynchronizationContext: false))
124-
#endif // NET || NETSTANDARD2_1_OR_GREATER
96+
using var args = new SocketAsyncEventArgs();
97+
args.SocketFlags = socketFlags;
98+
args.Completed += (_, _) => tcs.TrySetResult(null);
99+
args.SetBuffer(buffer, 0, buffer.Length);
100+
101+
if (socket.SendAsync(args))
102+
{
103+
using (cancellationToken.Register(() =>
104+
{
105+
if (tcs.TrySetCanceled(cancellationToken))
106+
{
107+
socket.Dispose();
108+
}
109+
},
110+
useSynchronizationContext: false))
125111
{
126-
await args.ExecuteAsync(socket.ReceiveAsync);
112+
_ = await tcs.Task.ConfigureAwait(false);
127113
}
114+
}
128115

129-
return args.BytesTransferred;
116+
if (args.SocketError != SocketError.Success)
117+
{
118+
throw new SocketException((int) args.SocketError);
130119
}
120+
121+
return args.BytesTransferred;
131122
}
123+
#endif // NETFRAMEWORK || NETSTANDARD2_0
132124
}
133125
}
134126
#endif

0 commit comments

Comments
 (0)