Skip to content

Commit 9e8ce4a

Browse files
committed
Modify ReceiveMessage to return null when connection is closed.
Added TrySocketRead method that returns 0 (zeo) when connection is closed. Remove SocketRead(int length) overload. Added IsConnected extension method to Socket. Modify MessageListener() to use this extension method as condition for the message loop. Do not bother checking readSockets as the connected check of the socket allows us to combine both the connection closed and socket disposed conditions.
1 parent b3d24ef commit 9e8ce4a

File tree

3 files changed

+88
-36
lines changed

3 files changed

+88
-36
lines changed

src/Renci.SshNet/Common/Extensions.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,5 +327,12 @@ internal static bool CanWrite(this Socket socket)
327327
{
328328
return SocketAbstraction.CanWrite(socket);
329329
}
330+
331+
internal static bool IsConnected(this Socket socket)
332+
{
333+
if (socket == null)
334+
return false;
335+
return socket.Connected;
336+
}
330337
}
331338
}

src/Renci.SshNet/Session.NET.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,19 @@ partial void IsSocketConnected(ref bool isConnected)
6161
lock (_socketDisposeLock)
6262
{
6363
#if FEATURE_SOCKET_POLL
64-
if (_socket == null || !_socket.Connected)
64+
if (!_socket.IsConnected())
6565
{
6666
isConnected = false;
6767
return;
6868
}
6969

7070
lock (_socketReadLock)
7171
{
72-
var connectionClosedOrDataAvailable = _socket.Poll(1, SelectMode.SelectRead);
72+
var connectionClosedOrDataAvailable = _socket.Poll(0, SelectMode.SelectRead);
7373
isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
7474
}
7575
#else
76-
isConnected = _socket != null && _socket.Connected;
76+
isConnected = _socket.IsConnected();
7777
#endif // FEATURE_SOCKET_POLL
7878
}
7979

src/Renci.SshNet/Session.cs

Lines changed: 78 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ internal void SendMessage(Message message)
829829
// atomically, and only after the packet has actually been sent
830830
lock (_socketWriteLock)
831831
{
832-
if (_socket == null || !_socket.Connected)
832+
if (!_socket.IsConnected())
833833
throw new SshConnectionException("Client not connected.");
834834

835835
byte[] hash = null;
@@ -914,9 +914,8 @@ private bool TrySendMessage(Message message)
914914
/// Receives the message from the server.
915915
/// </summary>
916916
/// <returns>
917-
/// The incoming SSH message.
917+
/// The incoming SSH message, or <c>null</c> if the connection with the SSH server was closed.
918918
/// </returns>
919-
/// <exception cref="SshConnectionException"></exception>
920919
/// <remarks>
921920
/// We need no locking here since all messages are read by a single thread.
922921
/// </remarks>
@@ -945,7 +944,12 @@ private Message ReceiveMessage()
945944
{
946945
#endif // FEATURE_SOCKET_POLL
947946
// Read first block - which starts with the packet length
948-
var firstBlock = SocketRead(blockSize);
947+
var firstBlock = new byte[blockSize];
948+
if (TrySocketRead(firstBlock, 0, blockSize) == 0)
949+
{
950+
// connection with SSH server was closed
951+
return null;
952+
}
949953

950954
#if DEBUG_GERT
951955
DiagnosticAbstraction.Log(string.Format("[{0}] FirstBlock [{1}]: {2}", ToHex(SessionId), blockSize, ToHex(firstBlock)));
@@ -988,7 +992,10 @@ private Message ReceiveMessage()
988992

989993
if (bytesToRead > 0)
990994
{
991-
SocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead);
995+
if (TrySocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead) == 0)
996+
{
997+
return null;
998+
}
992999
}
9931000
#if FEATURE_SOCKET_POLL
9941001
}
@@ -1787,20 +1794,6 @@ private void SocketConnect(string host, int port)
17871794
_socket.ReceiveBufferSize = socketBufferSize;
17881795
}
17891796

1790-
/// <summary>
1791-
/// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
1792-
/// </summary>
1793-
/// <param name="length">The number of bytes to read.</param>
1794-
/// <returns>
1795-
/// The bytes read from the server.
1796-
/// </returns>
1797-
private byte[] SocketRead(int length)
1798-
{
1799-
var buffer = new byte[length];
1800-
SocketRead(buffer, 0, length);
1801-
return buffer;
1802-
}
1803-
18041797
/// <summary>
18051798
/// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
18061799
/// </summary>
@@ -1827,6 +1820,22 @@ private int SocketRead(byte[] buffer, int offset, int length)
18271820
return bytesRead;
18281821
}
18291822

1823+
/// <summary>
1824+
/// Performs a blocking read on the socket until <paramref name="length"/> bytes are received.
1825+
/// </summary>
1826+
/// <param name="buffer">An array of type <see cref="byte"/> that is the storage location for the received data.</param>
1827+
/// <param name="offset">The position in <paramref name="buffer"/> parameter to store the received data.</param>
1828+
/// <param name="length">The number of bytes to read.</param>
1829+
/// <returns>
1830+
/// The number of bytes read.
1831+
/// </returns>
1832+
/// <exception cref="SshOperationTimeoutException">The read has timed-out.</exception>
1833+
/// <exception cref="SocketException">The read failed.</exception>
1834+
private int TrySocketRead(byte[] buffer, int offset, int length)
1835+
{
1836+
return SocketAbstraction.Read(_socket, buffer, offset, length, InfiniteTimeSpan);
1837+
}
1838+
18301839
/// <summary>
18311840
/// Performs a blocking read on the socket until a line is read.
18321841
/// </summary>
@@ -1920,26 +1929,56 @@ private void MessageListener()
19201929
var readSockets = new List<Socket> {_socket};
19211930

19221931
// remain in message loop until socket is shut down or until we're disconnecting
1923-
while (true)
1932+
while (_socket.IsConnected())
19241933
{
19251934
#if FEATURE_SOCKET_POLL
1935+
// if the socket is already disposed when Select is invoked, then a SocketException
1936+
// stating "An operation was attempted on something that is not a socket" is thrown;
1937+
// we attempt to avoid this exception by having an IsConnected() that can break the
1938+
// message loop
1939+
//
1940+
// note that there's no guarantee that the socket will not be disposed between the
1941+
// IsConnected() check and the Select invocation; we can't take a "dispose" lock
1942+
// that includes the Select invocation as we want Dispose() to be able to interrupt
1943+
// the Select
1944+
1945+
// perform a blocking select to determine whether there's is data available to be
1946+
// read; we do not use a blocking read to allow us to use Socket.Poll to determine
1947+
// if the connection is still available (in IsSocketConnected
19261948
Socket.Select(readSockets, null, null, -1);
19271949

1928-
if (readSockets.Count == 0)
1950+
// the Select invocation will be interrupted in one of the following conditions:
1951+
// * data is available to be read
1952+
// => the socket will not be removed from "readSockets"
1953+
// * the socket connection is closed during the Select invocation
1954+
// => the socket will be removed from "readSockets"
1955+
// * the socket is disposed during the Select invocation
1956+
// => the socket will not be removed from "readSocket"
1957+
//
1958+
// since we handle the second and third condition the same way and Socket.Connected
1959+
// allows us to check for both conditions, we use that instead of both checking for
1960+
// the removal from "readSockets" and the Connection check
1961+
if (!_socket.IsConnected())
1962+
{
1963+
// connection with SSH server was closed or socket was disposed;
1964+
// break out of the message loop
19291965
break;
1930-
1931-
// when the socket is disposed while a Select is executing, then the
1932-
// Select will be interrupted; the socket will not be removed from
1933-
// readSocket, and therefore we need to explicitly check if the
1934-
// socket is still connected
1966+
}
19351967
#endif // FEATURE_SOCKET_POLL
1936-
var socket = _socket;
1937-
if (socket == null || !socket.Connected)
1938-
break;
19391968

19401969
var message = ReceiveMessage();
1970+
if (message == null)
1971+
{
1972+
// connection with SSH server was closed;
1973+
// break out of the message loop
1974+
break;
1975+
}
1976+
19411977
HandleMessageCore(message);
19421978
}
1979+
1980+
// connection with SSH server was closed
1981+
RaiseError(CreateConnectionAbortedByServerException());
19431982
}
19441983
catch (SocketException ex)
19451984
{
@@ -2305,7 +2344,13 @@ private void Reset()
23052344
_keyExchangeInProgress = false;
23062345
}
23072346

2308-
#region IDisposable implementation
2347+
private static SshConnectionException CreateConnectionAbortedByServerException()
2348+
{
2349+
return new SshConnectionException("An established connection was aborted by the server.",
2350+
DisconnectReason.ConnectionLost);
2351+
}
2352+
2353+
#region IDisposable implementation
23092354

23102355
private bool _disposed;
23112356

@@ -2396,9 +2441,9 @@ protected virtual void Dispose(bool disposing)
23962441
Dispose(false);
23972442
}
23982443

2399-
#endregion IDisposable implementation
2444+
#endregion IDisposable implementation
24002445

2401-
#region ISession implementation
2446+
#region ISession implementation
24022447

24032448
/// <summary>
24042449
/// Gets or sets the connection info.
@@ -2483,6 +2528,6 @@ bool ISession.TrySendMessage(Message message)
24832528
return TrySendMessage(message);
24842529
}
24852530

2486-
#endregion ISession implementation
2531+
#endregion ISession implementation
24872532
}
24882533
}

0 commit comments

Comments
 (0)