Skip to content

Commit 395bbe7

Browse files
committed
Only use Socket.Poll, Socket.Select and read lock when FEATURE_SOCKET_POLL is defined.
Added test for SSH server shutdown while we're reading the packet.
1 parent f9ad893 commit 395bbe7

File tree

5 files changed

+204
-10
lines changed

5 files changed

+204
-10
lines changed

src/Renci.SshNet.Tests/Classes/SessionTest_ConnectedBase.cs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,22 @@ public void Setup()
4747
[TestCleanup]
4848
public void TearDown()
4949
{
50+
if (ServerSocket != null)
51+
{
52+
ServerSocket.Dispose();
53+
ServerSocket = null;
54+
}
55+
5056
if (ServerListener != null)
5157
{
5258
ServerListener.Dispose();
59+
ServerListener = null;
5360
}
5461

5562
if (Session != null)
5663
{
5764
Session.Dispose();
65+
Session = null;
5866
}
5967
}
6068

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
using System;
2+
using System.Diagnostics;
3+
using System.Net.Sockets;
4+
using System.Threading;
5+
using Microsoft.VisualStudio.TestTools.UnitTesting;
6+
using Renci.SshNet.Common;
7+
using Renci.SshNet.Messages.Transport;
8+
9+
namespace Renci.SshNet.Tests.Classes
10+
{
11+
[TestClass]
12+
public class SessionTest_Connected_ServerShutsDownSendAfterSendingIncompletePacket : SessionTest_ConnectedBase
13+
{
14+
protected override void Act()
15+
{
16+
var incompletePacket = new byte[] {0x0a, 0x05, 0x05};
17+
ServerSocket.Send(incompletePacket, 0, incompletePacket.Length, SocketFlags.None);
18+
19+
// give session some time to start reading packet
20+
Thread.Sleep(100);
21+
22+
ServerSocket.Shutdown(SocketShutdown.Send);
23+
24+
// give session some time to process shut down of server socket
25+
Thread.Sleep(100);
26+
}
27+
28+
[TestMethod]
29+
public void IsConnectedShouldReturnFalse()
30+
{
31+
Assert.IsFalse(Session.IsConnected);
32+
}
33+
34+
[TestMethod]
35+
public void DisconnectShouldFinishImmediately()
36+
{
37+
var stopwatch = new Stopwatch();
38+
stopwatch.Start();
39+
40+
Session.Disconnect();
41+
42+
stopwatch.Stop();
43+
Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
44+
}
45+
46+
[TestMethod]
47+
public void DisconnectedIsNeverRaised()
48+
{
49+
Assert.AreEqual(0, DisconnectedRegister.Count);
50+
}
51+
52+
[TestMethod]
53+
public void DisconnectReceivedIsNeverRaised()
54+
{
55+
Assert.AreEqual(0, DisconnectReceivedRegister.Count);
56+
}
57+
58+
[TestMethod]
59+
public void ErrorOccurredIsRaisedOnce()
60+
{
61+
Assert.AreEqual(1, ErrorOccurredRegister.Count);
62+
63+
var errorOccurred = ErrorOccurredRegister[0];
64+
Assert.IsNotNull(errorOccurred);
65+
66+
var exception = errorOccurred.Exception;
67+
Assert.IsNotNull(exception);
68+
Assert.AreEqual(typeof(SshConnectionException), exception.GetType());
69+
70+
var connectionException = (SshConnectionException) exception;
71+
Assert.AreEqual(DisconnectReason.ConnectionLost, connectionException.DisconnectReason);
72+
Assert.IsNull(connectionException.InnerException);
73+
Assert.AreEqual("An established connection was aborted by the server.", connectionException.Message);
74+
}
75+
76+
[TestMethod]
77+
public void DisposeShouldFinishImmediately()
78+
{
79+
var stopwatch = new Stopwatch();
80+
stopwatch.Start();
81+
82+
Session.Dispose();
83+
84+
stopwatch.Stop();
85+
Assert.IsTrue(stopwatch.ElapsedMilliseconds < 500);
86+
}
87+
88+
[TestMethod]
89+
public void ReceiveOnServerSocketShouldReturnZero()
90+
{
91+
var buffer = new byte[1];
92+
93+
var actual = ServerSocket.Receive(buffer, 0, buffer.Length, SocketFlags.None);
94+
95+
Assert.AreEqual(0, actual);
96+
}
97+
98+
[TestMethod]
99+
public void SendMessageShouldSucceed()
100+
{
101+
try
102+
{
103+
Session.SendMessage(new IgnoreMessage());
104+
Assert.Fail();
105+
}
106+
catch (SshConnectionException ex)
107+
{
108+
Assert.IsNull(ex.InnerException);
109+
Assert.AreEqual("Client not connected.", ex.Message);
110+
}
111+
}
112+
113+
[TestMethod]
114+
public void ISession_MessageListenerCompletedShouldBeSignaled()
115+
{
116+
var session = (ISession) Session;
117+
118+
Assert.IsNotNull(session.MessageListenerCompleted);
119+
Assert.IsTrue(session.MessageListenerCompleted.WaitOne());
120+
}
121+
122+
[TestMethod]
123+
public void ISession_SendMessageShouldSucceed()
124+
{
125+
var session = (ISession) Session;
126+
127+
try
128+
{
129+
session.SendMessage(new IgnoreMessage());
130+
Assert.Fail();
131+
}
132+
catch (SshConnectionException ex)
133+
{
134+
Assert.IsNull(ex.InnerException);
135+
Assert.AreEqual("Client not connected.", ex.Message);
136+
}
137+
}
138+
139+
[TestMethod]
140+
public void ISession_TrySendMessageShouldReturnTrue()
141+
{
142+
var session = (ISession) Session;
143+
144+
Assert.IsFalse(session.TrySendMessage(new IgnoreMessage()));
145+
}
146+
147+
[TestMethod]
148+
public void ISession_WaitOnHandleShouldThrowSshConnectionExceptionDetailingBadPacket()
149+
{
150+
var session = (ISession) Session;
151+
var waitHandle = new ManualResetEvent(false);
152+
153+
try
154+
{
155+
session.WaitOnHandle(waitHandle);
156+
Assert.Fail();
157+
}
158+
catch (SshConnectionException ex)
159+
{
160+
Assert.AreEqual(DisconnectReason.ConnectionLost, ex.DisconnectReason);
161+
Assert.IsNull(ex.InnerException);
162+
Assert.AreEqual("An established connection was aborted by the server.", ex.Message);
163+
}
164+
}
165+
}
166+
}

src/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerShutsDownSocket.cs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using System;
2-
using System.Diagnostics;
1+
using System.Diagnostics;
32
using System.Net.Sockets;
43
using System.Threading;
54
using Microsoft.VisualStudio.TestTools.UnitTesting;

src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@
258258
<Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessage.cs" />
259259
<Compile Include="Classes\SessionTest_Connected_ServerSendsBadPacket.cs" />
260260
<Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs" />
261+
<Compile Include="Classes\SessionTest_Connected_ServerShutsDownSendAfterSendingIncompletePacket.cs" />
261262
<Compile Include="Classes\SessionTest_Connected_ServerShutsDownSocket.cs" />
262263
<Compile Include="Classes\SessionTest_NotConnected.cs" />
263264
<Compile Include="Classes\SessionTest_SocketConnected_BadPacketAndDispose.cs" />

src/Renci.SshNet/Session.cs

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,13 @@ public partial class Session : ISession
162162
/// </summary>
163163
private Socket _socket;
164164

165+
#if FEATURE_SOCKET_POLL
165166
/// <summary>
166167
/// Holds an object that is used to ensure only a single thread can read from
167168
/// <see cref="_socket"/> at any given time.
168169
/// </summary>
169170
private readonly object _socketReadLock = new object();
171+
#endif // FEATURE_SOCKET_POLL
170172

171173
/// <summary>
172174
/// Holds an object that is used to ensure only a single thread can write to
@@ -918,8 +920,13 @@ private Message ReceiveMessage()
918920
byte[] data;
919921
uint packetLength;
920922

923+
#if FEATURE_SOCKET_POLL
924+
// avoid reading from socket while IsSocketConnected is attempting to determine whether the
925+
// socket is still connected by invoking Socket.Poll(...) and subsequently verifying value of
926+
// Socket.Available
921927
lock (_socketReadLock)
922928
{
929+
#endif // FEATURE_SOCKET_POLL
923930
// Read first block - which starts with the packet length
924931
var firstBlock = SocketRead(blockSize);
925932

@@ -938,12 +945,14 @@ private Message ReceiveMessage()
938945
packetLength = (uint) (firstBlock[0] << 24 | firstBlock[1] << 16 | firstBlock[2] << 8 | firstBlock[3]);
939946

940947
// Test packet minimum and maximum boundaries
941-
if (packetLength < Math.Max((byte)16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
942-
throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength), DisconnectReason.ProtocolError);
948+
if (packetLength < Math.Max((byte) 16, blockSize) - 4 || packetLength > MaximumSshPacketSize - 4)
949+
throw new SshConnectionException(
950+
string.Format(CultureInfo.CurrentCulture, "Bad packet length: {0}.", packetLength),
951+
DisconnectReason.ProtocolError);
943952

944953
// Determine the number of bytes left to read; We've already read "blockSize" bytes, but the
945954
// "packet length" field itself - which is 4 bytes - is not included in the length of the packet
946-
var bytesToRead = (int)(packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
955+
var bytesToRead = (int) (packetLength - (blockSize - packetLengthFieldLength)) + serverMacLength;
947956

948957
// Construct buffer for holding the payload and the inbound packet sequence as we need both in order
949958
// to generate the hash.
@@ -964,7 +973,9 @@ private Message ReceiveMessage()
964973
{
965974
SocketRead(data, blockSize + inboundPacketSequenceLength, bytesToRead);
966975
}
976+
#if FEATURE_SOCKET_POLL
967977
}
978+
#endif // FEATURE_SOCKET_POLL
968979

969980
if (_serverCipher != null)
970981
{
@@ -1856,14 +1867,19 @@ private void SocketDisconnectAndDispose()
18561867
// interrupt any pending reads
18571868
_socket.Shutdown(SocketShutdown.Send);
18581869

1859-
// since we've shut down the socket, there should not be
1860-
// any reads in progress but we still take a read lock
1861-
// to ensure IsSocketConnected continues to provide
1870+
#if FEATURE_SOCKET_POLL
1871+
// since we've shut down the socket, there should not be any reads in progress but
1872+
// we still take a read lock to ensure IsSocketConnected continues to provide
18621873
// correct results
1874+
//
1875+
// only necessary if IsSocketConnected actually uses Socket.Poll.
18631876
lock (_socketReadLock)
18641877
{
1878+
#endif // FEATURE_SOCKET_POLL
18651879
SocketAbstraction.ClearReadBuffer(_socket);
1880+
#if FEATURE_SOCKET_POLL
18661881
}
1882+
#endif // FEATURE_SOCKET_POLL
18671883
}
18681884

18691885
_socket.Dispose();
@@ -1882,16 +1898,20 @@ private void MessageListener()
18821898
{
18831899
var readSockets = new List<Socket> {_socket};
18841900

1885-
while (_socket != null)
1901+
// remain in message loop until socket is shut down
1902+
while (true)
18861903
{
1904+
#if FEATURE_SOCKET_POLL
18871905
Socket.Select(readSockets, null, null, -1);
18881906

18891907
if (readSockets.Count == 0)
18901908
break;
18911909

18921910
// when the socket is disposed while a Select is executing, then the
18931911
// Select will be interrupted; the socket will not be removed from
1894-
// readSocket
1912+
// readSocket, and therefore we need to explicitly check if the
1913+
// socket is still connected
1914+
#endif // FEATURE_SOCKET_POLL
18951915
var socket = _socket;
18961916
if (socket == null || !socket.Connected)
18971917
break;

0 commit comments

Comments
 (0)