Skip to content

Commit fea0406

Browse files
committed
Cleanup SocketAbstraction
* Use "using" and ManualResetEventSlim in Connect * Delete unused and unnecessary methods
1 parent dfa72c3 commit fea0406

File tree

9 files changed

+48
-192
lines changed

9 files changed

+48
-192
lines changed

src/Renci.SshNet/Abstractions/SocketAbstraction.cs

Lines changed: 10 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -3,148 +3,43 @@
33
using System.Net;
44
using System.Net.Sockets;
55
using System.Threading;
6+
#if NET6_0_OR_GREATER == false
67
using System.Threading.Tasks;
8+
#endif
79

810
using Renci.SshNet.Common;
911

1012
namespace Renci.SshNet.Abstractions
1113
{
1214
internal static partial class SocketAbstraction
1315
{
14-
public static bool CanRead(Socket socket)
15-
{
16-
if (socket.Connected)
17-
{
18-
return socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
19-
}
20-
21-
return false;
22-
}
23-
24-
/// <summary>
25-
/// Returns a value indicating whether the specified <see cref="Socket"/> can be used
26-
/// to send data.
27-
/// </summary>
28-
/// <param name="socket">The <see cref="Socket"/> to check.</param>
29-
/// <returns>
30-
/// <see langword="true"/> if <paramref name="socket"/> can be written to; otherwise, <see langword="false"/>.
31-
/// </returns>
32-
public static bool CanWrite(Socket socket)
33-
{
34-
if (socket != null && socket.Connected)
35-
{
36-
return socket.Poll(-1, SelectMode.SelectWrite);
37-
}
38-
39-
return false;
40-
}
41-
42-
public static Socket Connect(IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
43-
{
44-
var socket = new Socket(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp) { NoDelay = true };
45-
ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: true);
46-
return socket;
47-
}
48-
4916
public static void Connect(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout)
5017
{
51-
ConnectCore(socket, remoteEndpoint, connectTimeout, ownsSocket: false);
52-
}
53-
54-
public static async Task ConnectAsync(Socket socket, IPEndPoint remoteEndpoint, CancellationToken cancellationToken)
55-
{
56-
await socket.ConnectAsync(remoteEndpoint, cancellationToken).ConfigureAwait(false);
57-
}
58-
59-
private static void ConnectCore(Socket socket, IPEndPoint remoteEndpoint, TimeSpan connectTimeout, bool ownsSocket)
60-
{
61-
var connectCompleted = new ManualResetEvent(initialState: false);
62-
var args = new SocketAsyncEventArgs
63-
{
64-
UserToken = connectCompleted,
65-
RemoteEndPoint = remoteEndpoint
66-
};
67-
args.Completed += ConnectCompleted;
18+
using var connectCompleted = new ManualResetEventSlim(initialState: false);
19+
using var args = new SocketAsyncEventArgs
20+
{
21+
RemoteEndPoint = remoteEndpoint
22+
};
23+
args.Completed += (_, _) => connectCompleted.Set();
6824

6925
if (socket.ConnectAsync(args))
7026
{
71-
if (!connectCompleted.WaitOne(connectTimeout))
27+
if (!connectCompleted.Wait(connectTimeout))
7228
{
73-
// avoid ObjectDisposedException in ConnectCompleted
74-
args.Completed -= ConnectCompleted;
75-
if (ownsSocket)
76-
{
77-
// dispose Socket
78-
socket.Dispose();
79-
}
80-
81-
// dispose ManualResetEvent
82-
connectCompleted.Dispose();
83-
84-
// dispose SocketAsyncEventArgs
85-
args.Dispose();
29+
socket.Dispose();
8630

8731
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
8832
"Connection failed to establish within {0:F0} milliseconds.",
8933
connectTimeout.TotalMilliseconds));
9034
}
9135
}
9236

93-
// dispose ManualResetEvent
94-
connectCompleted.Dispose();
95-
9637
if (args.SocketError != SocketError.Success)
9738
{
9839
var socketError = (int) args.SocketError;
9940

100-
if (ownsSocket)
101-
{
102-
// dispose Socket
103-
socket.Dispose();
104-
}
105-
106-
// dispose SocketAsyncEventArgs
107-
args.Dispose();
108-
10941
throw new SocketException(socketError);
11042
}
111-
112-
// dispose SocketAsyncEventArgs
113-
args.Dispose();
114-
}
115-
116-
public static void ClearReadBuffer(Socket socket)
117-
{
118-
var timeout = TimeSpan.FromMilliseconds(500);
119-
var buffer = new byte[256];
120-
int bytesReceived;
121-
122-
do
123-
{
124-
bytesReceived = ReadPartial(socket, buffer, 0, buffer.Length, timeout);
125-
}
126-
while (bytesReceived > 0);
127-
}
128-
129-
public static int ReadPartial(Socket socket, byte[] buffer, int offset, int size, TimeSpan timeout)
130-
{
131-
socket.ReceiveTimeout = timeout.AsTimeout(nameof(timeout));
132-
133-
try
134-
{
135-
return socket.Receive(buffer, offset, size, SocketFlags.None);
136-
}
137-
catch (SocketException ex)
138-
{
139-
if (ex.SocketErrorCode == SocketError.TimedOut)
140-
{
141-
throw new SshOperationTimeoutException(string.Format(CultureInfo.InvariantCulture,
142-
"Socket read operation has timed out after {0:F0} milliseconds.",
143-
timeout.TotalMilliseconds));
144-
}
145-
146-
throw;
147-
}
14843
}
14944

15045
public static void ReadContinuous(Socket socket, byte[] buffer, int offset, int size, Action<byte[], int, int> processReceivedBytesAction)
@@ -206,41 +101,6 @@ public static int ReadByte(Socket socket, TimeSpan timeout)
206101
return buffer[0];
207102
}
208103

209-
/// <summary>
210-
/// Sends a byte using the specified <see cref="Socket"/>.
211-
/// </summary>
212-
/// <param name="socket">The <see cref="Socket"/> to write to.</param>
213-
/// <param name="value">The value to send.</param>
214-
/// <exception cref="SocketException">The write failed.</exception>
215-
public static void SendByte(Socket socket, byte value)
216-
{
217-
var buffer = new[] { value };
218-
_ = socket.Send(buffer);
219-
}
220-
221-
/// <summary>
222-
/// Receives data from a bound <see cref="Socket"/>.
223-
/// </summary>
224-
/// <param name="socket">The <see cref="Socket"/> to read from.</param>
225-
/// <param name="size">The number of bytes to receive.</param>
226-
/// <param name="timeout">Specifies the amount of time after which the call will time out.</param>
227-
/// <returns>
228-
/// The bytes received.
229-
/// </returns>
230-
/// <remarks>
231-
/// If no data is available for reading, the <see cref="Read(Socket, int, TimeSpan)"/> method will
232-
/// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the
233-
/// <see cref="Read(Socket, int, TimeSpan)"/> call will throw a <see cref="SshOperationTimeoutException"/>.
234-
/// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the
235-
/// <see cref="Read(Socket, int, TimeSpan)"/> method will complete immediately and throw a <see cref="SocketException"/>.
236-
/// </remarks>
237-
public static byte[] Read(Socket socket, int size, TimeSpan timeout)
238-
{
239-
var buffer = new byte[size];
240-
_ = Read(socket, buffer, 0, size, timeout);
241-
return buffer;
242-
}
243-
244104
/// <summary>
245105
/// Receives data from a bound <see cref="Socket"/> into a receive buffer.
246106
/// </summary>
@@ -258,10 +118,6 @@ public static byte[] Read(Socket socket, int size, TimeSpan timeout)
258118
/// block until data is available or the time-out value is exceeded. If the time-out value is exceeded, the
259119
/// <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> call will throw a <see cref="SshOperationTimeoutException"/>.
260120
/// </para>
261-
/// <para>
262-
/// If you are in non-blocking mode, and there is no data available in the in the protocol stack buffer, the
263-
/// <see cref="Read(Socket, byte[], int, int, TimeSpan)"/> method will complete immediately and throw a <see cref="SocketException"/>.
264-
/// </para>
265121
/// </remarks>
266122
public static int Read(Socket socket, byte[] buffer, int offset, int size, TimeSpan readTimeout)
267123
{
@@ -301,10 +157,5 @@ public static ValueTask<int> ReadAsync(Socket socket, byte[] buffer, Cancellatio
301157
return socket.ReceiveAsync(new ArraySegment<byte>(buffer, 0, buffer.Length), SocketFlags.None, cancellationToken);
302158
}
303159
#endif
304-
private static void ConnectCompleted(object sender, SocketAsyncEventArgs e)
305-
{
306-
var eventWaitHandle = (ManualResetEvent) e.UserToken;
307-
_ = eventWaitHandle?.Set();
308-
}
309160
}
310161
}

src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System.Net.Sockets;
44
using Renci.SshNet.Abstractions;
55
using Renci.SshNet.Common;
6+
using Renci.SshNet.Connection;
67
using Renci.SshNet.Messages.Connection;
78

89
namespace Renci.SshNet.Channels
@@ -13,20 +14,23 @@ namespace Renci.SshNet.Channels
1314
internal sealed class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip
1415
{
1516
private readonly object _socketShutdownAndCloseLock = new object();
17+
private readonly ISocketFactory _socketFactory;
1618
private Socket _socket;
1719
private IForwardedPort _forwardedPort;
1820

1921
/// <summary>
2022
/// Initializes a new instance of the <see cref="ChannelForwardedTcpip"/> class.
2123
/// </summary>
2224
/// <param name="session">The session.</param>
25+
/// <param name="socketFactory">The socket factory.</param>
2326
/// <param name="localChannelNumber">The local channel number.</param>
2427
/// <param name="localWindowSize">Size of the window.</param>
2528
/// <param name="localPacketSize">Size of the packet.</param>
2629
/// <param name="remoteChannelNumber">The remote channel number.</param>
2730
/// <param name="remoteWindowSize">The window size of the remote party.</param>
2831
/// <param name="remotePacketSize">The maximum size of a data packet that we can send to the remote party.</param>
2932
internal ChannelForwardedTcpip(ISession session,
33+
ISocketFactory socketFactory,
3034
uint localChannelNumber,
3135
uint localWindowSize,
3236
uint localPacketSize,
@@ -41,6 +45,7 @@ internal ChannelForwardedTcpip(ISession session,
4145
remoteWindowSize,
4246
remotePacketSize)
4347
{
48+
_socketFactory = socketFactory;
4449
}
4550

4651
/// <summary>
@@ -72,7 +77,9 @@ public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort)
7277
// Try to connect to the socket
7378
try
7479
{
75-
_socket = SocketAbstraction.Connect(remoteEndpoint, ConnectionInfo.Timeout);
80+
_socket = _socketFactory.Create(remoteEndpoint.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
81+
82+
SocketAbstraction.Connect(_socket, remoteEndpoint, ConnectionInfo.Timeout);
7683

7784
// Send channel open confirmation message
7885
SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber));

src/Renci.SshNet/Common/Extensions.cs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
using System.Net;
66
using System.Net.Sockets;
77
using System.Text;
8-
using Renci.SshNet.Abstractions;
8+
99
using Renci.SshNet.Messages;
1010

1111
namespace Renci.SshNet.Common
@@ -336,22 +336,17 @@ public static byte[] Concat(this byte[] first, byte[] second)
336336

337337
internal static bool CanRead(this Socket socket)
338338
{
339-
return SocketAbstraction.CanRead(socket);
339+
return socket.Connected && socket.Poll(-1, SelectMode.SelectRead) && socket.Available > 0;
340340
}
341341

342342
internal static bool CanWrite(this Socket socket)
343343
{
344-
return SocketAbstraction.CanWrite(socket);
344+
return socket is not null && socket.Connected && socket.Poll(-1, SelectMode.SelectWrite);
345345
}
346346

347347
internal static bool IsConnected(this Socket socket)
348348
{
349-
if (socket is null)
350-
{
351-
return false;
352-
}
353-
354-
return socket.Connected;
349+
return socket is not null && socket.Connected;
355350
}
356351
}
357352
}

src/Renci.SshNet/Connection/ConnectorBase.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ protected async Task<Socket> SocketConnectAsync(string host, int port, Cancellat
8686
var socket = SocketFactory.Create(ep.AddressFamily, SocketType.Stream, ProtocolType.Tcp);
8787
try
8888
{
89-
await SocketAbstraction.ConnectAsync(socket, ep, cancellationToken).ConfigureAwait(false);
89+
await socket.ConnectAsync(ep, cancellationToken).ConfigureAwait(false);
9090

9191
const int socketBufferSize = 2 * Session.MaximumSshPacketSize;
9292
socket.SendBufferSize = socketBufferSize;

src/Renci.SshNet/Connection/Socks5Connector.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ protected override void HandleProxyConnect(IConnectionInfo connectionInfo, Socke
6363
_ = socket.Send(authenticationRequest);
6464

6565
// Read authentication result
66-
var authenticationResult = SocketAbstraction.Read(socket, 2, connectionInfo.Timeout);
66+
var authenticationResult = new byte[2];
67+
_ = SocketAbstraction.Read(socket, authenticationResult, 0, authenticationResult.Length, connectionInfo.Timeout);
6768

6869
if (authenticationResult[0] != 0x01)
6970
{

src/Renci.SshNet/ForwardedPortDynamic.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -503,18 +503,18 @@ private bool HandleSocks4(Socket socket, IChannelDirectTcpip channel, TimeSpan t
503503

504504
channel.Open(host, port, this, socket);
505505

506-
SocketAbstraction.SendByte(socket, 0x00);
506+
_ = socket.Send([0x00]);
507507

508508
if (channel.IsOpen)
509509
{
510-
SocketAbstraction.SendByte(socket, 0x5a);
510+
_ = socket.Send([0x5a]);
511511
_ = socket.Send(portBuffer);
512512
_ = socket.Send(ipBuffer);
513513
return true;
514514
}
515515

516516
// signal that request was rejected or failed
517-
SocketAbstraction.SendByte(socket, 0x5b);
517+
_ = socket.Send([0x5b]);
518518
return false;
519519
}
520520

src/Renci.SshNet/Session.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2232,6 +2232,7 @@ IChannelForwardedTcpip ISession.CreateChannelForwardedTcpip(uint remoteChannelNu
22322232
uint remoteChannelDataPacketSize)
22332233
{
22342234
return new ChannelForwardedTcpip(this,
2235+
_socketFactory,
22352236
NextChannelNumber,
22362237
InitialLocalWindowSize,
22372238
LocalChannelDataPacketSize,

0 commit comments

Comments
 (0)