Skip to content

Commit f9ad893

Browse files
committed
Rename Read(int length) to SocketRead(int length)
Introduce dispose lock to resolve race condition in IsConnected/IsSocketConnected. Rename _socketLock to _socketWriteLock Eliminate extra allocations in ReceiveMessage, and combine two socket reads. Use separate lock to eliminate race condition in IsSocketConnected between Poll and checking the Available property. Modify SocketRead(int length, byte[] buffer) to also take offset. Modify MessageListener to use Select instead of blocking Receive. Fixes issue #80.
1 parent 6a1859c commit f9ad893

File tree

5 files changed

+403
-180
lines changed

5 files changed

+403
-180
lines changed
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Globalization;
4+
using System.Net;
5+
using System.Net.Sockets;
6+
using System.Security.Cryptography;
7+
using System.Text;
8+
using Microsoft.VisualStudio.TestTools.UnitTesting;
9+
using Moq;
10+
using Renci.SshNet.Common;
11+
using Renci.SshNet.Compression;
12+
using Renci.SshNet.Messages;
13+
using Renci.SshNet.Messages.Transport;
14+
using Renci.SshNet.Security;
15+
using Renci.SshNet.Security.Cryptography;
16+
using Renci.SshNet.Tests.Common;
17+
18+
namespace Renci.SshNet.Tests.Classes
19+
{
20+
[TestClass]
21+
public class SessionTest_Connected_ServerAndClientDisconnectRace
22+
{
23+
private Mock<IServiceFactory> _serviceFactoryMock;
24+
private Mock<IKeyExchange> _keyExchangeMock;
25+
private Mock<IClientAuthentication> _clientAuthenticationMock;
26+
private IPEndPoint _serverEndPoint;
27+
private string _keyExchangeAlgorithm;
28+
private DisconnectMessage _disconnectMessage;
29+
30+
protected Random Random { get; private set; }
31+
protected byte[] SessionId { get; private set; }
32+
protected ConnectionInfo ConnectionInfo { get; private set; }
33+
protected IList<EventArgs> DisconnectedRegister { get; private set; }
34+
protected IList<MessageEventArgs<DisconnectMessage>> DisconnectReceivedRegister { get; private set; }
35+
protected IList<ExceptionEventArgs> ErrorOccurredRegister { get; private set; }
36+
protected AsyncSocketListener ServerListener { get; private set; }
37+
protected IList<byte[]> ServerBytesReceivedRegister { get; private set; }
38+
protected Session Session { get; private set; }
39+
protected Socket ServerSocket { get; private set; }
40+
41+
private void TearDown()
42+
{
43+
if (ServerListener != null)
44+
{
45+
ServerListener.Dispose();
46+
}
47+
48+
if (Session != null)
49+
{
50+
Session.Dispose();
51+
}
52+
}
53+
54+
protected virtual void SetupData()
55+
{
56+
Random = new Random();
57+
58+
_serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
59+
ConnectionInfo = new ConnectionInfo(
60+
_serverEndPoint.Address.ToString(),
61+
_serverEndPoint.Port,
62+
"user",
63+
new PasswordAuthenticationMethod("user", "password"))
64+
{ Timeout = TimeSpan.FromSeconds(20) };
65+
_keyExchangeAlgorithm = Random.Next().ToString(CultureInfo.InvariantCulture);
66+
SessionId = new byte[10];
67+
Random.NextBytes(SessionId);
68+
DisconnectedRegister = new List<EventArgs>();
69+
DisconnectReceivedRegister = new List<MessageEventArgs<DisconnectMessage>>();
70+
ErrorOccurredRegister = new List<ExceptionEventArgs>();
71+
ServerBytesReceivedRegister = new List<byte[]>();
72+
_disconnectMessage = new DisconnectMessage(DisconnectReason.ServiceNotAvailable, "Not today!");
73+
74+
Session = new Session(ConnectionInfo, _serviceFactoryMock.Object);
75+
Session.Disconnected += (sender, args) => DisconnectedRegister.Add(args);
76+
Session.DisconnectReceived += (sender, args) => DisconnectReceivedRegister.Add(args);
77+
Session.ErrorOccured += (sender, args) => ErrorOccurredRegister.Add(args);
78+
Session.KeyExchangeInitReceived += (sender, args) =>
79+
{
80+
var newKeysMessage = new NewKeysMessage();
81+
var newKeys = newKeysMessage.GetPacket(8, null);
82+
ServerSocket.Send(newKeys, 4, newKeys.Length - 4, SocketFlags.None);
83+
};
84+
85+
ServerListener = new AsyncSocketListener(_serverEndPoint);
86+
ServerListener.Connected += socket =>
87+
{
88+
ServerSocket = socket;
89+
90+
socket.Send(Encoding.ASCII.GetBytes("\r\n"));
91+
socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
92+
socket.Send(Encoding.ASCII.GetBytes("SSH-2.0-SshStub\r\n"));
93+
};
94+
95+
var counter = 0;
96+
97+
ServerListener.BytesReceived += (received, socket) =>
98+
{
99+
ServerBytesReceivedRegister.Add(received);
100+
101+
switch (counter++)
102+
{
103+
case 0:
104+
var keyExchangeInitMessage = new KeyExchangeInitMessage
105+
{
106+
CompressionAlgorithmsClientToServer = new string[0],
107+
CompressionAlgorithmsServerToClient = new string[0],
108+
EncryptionAlgorithmsClientToServer = new string[0],
109+
EncryptionAlgorithmsServerToClient = new string[0],
110+
KeyExchangeAlgorithms = new[] { _keyExchangeAlgorithm },
111+
LanguagesClientToServer = new string[0],
112+
LanguagesServerToClient = new string[0],
113+
MacAlgorithmsClientToServer = new string[0],
114+
MacAlgorithmsServerToClient = new string[0],
115+
ServerHostKeyAlgorithms = new string[0]
116+
};
117+
var keyExchangeInit = keyExchangeInitMessage.GetPacket(8, null);
118+
ServerSocket.Send(keyExchangeInit, 4, keyExchangeInit.Length - 4, SocketFlags.None);
119+
break;
120+
case 1:
121+
var serviceAcceptMessage =ServiceAcceptMessageBuilder.Create(ServiceName.UserAuthentication).Build();
122+
ServerSocket.Send(serviceAcceptMessage, 0, serviceAcceptMessage.Length, SocketFlags.None);
123+
break;
124+
}
125+
};
126+
}
127+
128+
private void CreateMocks()
129+
{
130+
_serviceFactoryMock = new Mock<IServiceFactory>(MockBehavior.Strict);
131+
_keyExchangeMock = new Mock<IKeyExchange>(MockBehavior.Strict);
132+
_clientAuthenticationMock = new Mock<IClientAuthentication>(MockBehavior.Strict);
133+
}
134+
135+
private void SetupMocks()
136+
{
137+
_serviceFactoryMock.Setup(
138+
p =>
139+
p.CreateKeyExchange(ConnectionInfo.KeyExchangeAlgorithms, new[] { _keyExchangeAlgorithm })).Returns(_keyExchangeMock.Object);
140+
_keyExchangeMock.Setup(p => p.Name).Returns(_keyExchangeAlgorithm);
141+
_keyExchangeMock.Setup(p => p.Start(Session, It.IsAny<KeyExchangeInitMessage>()));
142+
_keyExchangeMock.Setup(p => p.ExchangeHash).Returns(SessionId);
143+
_keyExchangeMock.Setup(p => p.CreateServerCipher()).Returns((Cipher)null);
144+
_keyExchangeMock.Setup(p => p.CreateClientCipher()).Returns((Cipher)null);
145+
_keyExchangeMock.Setup(p => p.CreateServerHash()).Returns((HashAlgorithm)null);
146+
_keyExchangeMock.Setup(p => p.CreateClientHash()).Returns((HashAlgorithm)null);
147+
_keyExchangeMock.Setup(p => p.CreateCompressor()).Returns((Compressor)null);
148+
_keyExchangeMock.Setup(p => p.CreateDecompressor()).Returns((Compressor)null);
149+
_keyExchangeMock.Setup(p => p.Dispose());
150+
_serviceFactoryMock.Setup(p => p.CreateClientAuthentication()).Returns(_clientAuthenticationMock.Object);
151+
_clientAuthenticationMock.Setup(p => p.Authenticate(ConnectionInfo, Session));
152+
}
153+
154+
protected virtual void Arrange()
155+
{
156+
CreateMocks();
157+
SetupData();
158+
SetupMocks();
159+
160+
ServerListener.Start();
161+
Session.Connect();
162+
}
163+
164+
[TestMethod]
165+
public void Act()
166+
{
167+
for (var i = 0; i < 50; i++)
168+
{
169+
Arrange();
170+
try
171+
{
172+
var disconnect = _disconnectMessage.GetPacket(8, null);
173+
ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
174+
Session.Disconnect();
175+
}
176+
finally
177+
{
178+
TearDown();
179+
}
180+
}
181+
}
182+
183+
private class ServiceAcceptMessageBuilder
184+
{
185+
private readonly ServiceName _serviceName;
186+
187+
private ServiceAcceptMessageBuilder(ServiceName serviceName)
188+
{
189+
_serviceName = serviceName;
190+
}
191+
192+
public static ServiceAcceptMessageBuilder Create(ServiceName serviceName)
193+
{
194+
return new ServiceAcceptMessageBuilder(serviceName);
195+
}
196+
197+
public byte[] Build()
198+
{
199+
var serviceName = _serviceName.ToArray();
200+
201+
var sshDataStream = new SshDataStream(4 + 1 + 1 + 4 + serviceName.Length);
202+
sshDataStream.Write((uint)(sshDataStream.Capacity - 4)); // packet length
203+
sshDataStream.WriteByte(0); // padding length
204+
sshDataStream.WriteByte(ServiceAcceptMessage.MessageNumber);
205+
sshDataStream.WriteBinary(serviceName);
206+
return sshDataStream.ToArray();
207+
}
208+
}
209+
}
210+
}

src/Renci.SshNet.Tests/Classes/SessionTest_Connected_ServerSendsDisconnectMessage.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
namespace Renci.SshNet.Tests.Classes
1010
{
11-
[TestClass]
1211
public class SessionTest_Connected_ServerSendsDisconnectMessage : SessionTest_ConnectedBase
1312
{
1413
private DisconnectMessage _disconnectMessage;
@@ -25,8 +24,7 @@ protected override void Act()
2524
var disconnect = _disconnectMessage.GetPacket(8, null);
2625
ServerSocket.Send(disconnect, 4, disconnect.Length - 4, SocketFlags.None);
2726

28-
// give session some time to process DisconnectMessage
29-
Thread.Sleep(200);
27+
Session.Disconnect();
3028
}
3129

3230
[TestMethod]

src/Renci.SshNet.Tests/Renci.SshNet.Tests.csproj

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@
254254
<Compile Include="Classes\SessionTest_Connected_ConnectionReset.cs" />
255255
<Compile Include="Classes\SessionTest_Connected_Disconnect.cs" />
256256
<Compile Include="Classes\SessionTest_Connected_GlobalRequestMessageAfterAuthenticationRace.cs" />
257+
<Compile Include="Classes\SessionTest_Connected_ServerAndClientDisconnectRace.cs" />
257258
<Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessage.cs" />
258259
<Compile Include="Classes\SessionTest_Connected_ServerSendsBadPacket.cs" />
259260
<Compile Include="Classes\SessionTest_Connected_ServerSendsDisconnectMessageAndShutsDownSocket.cs" />
@@ -535,12 +536,6 @@
535536
<Compile Include="Classes\Sftp\SftpSynchronizeDirectoriesAsyncResultTest.cs" />
536537
<Compile Include="Classes\Sftp\SftpUploadAsyncResultTest.cs" />
537538
</ItemGroup>
538-
<ItemGroup>
539-
<ProjectReference Include="..\Renci.SshNet\Renci.SshNet.csproj">
540-
<Project>{2F5F8C90-0BD1-424F-997C-7BC6280919D1}</Project>
541-
<Name>Renci.SshNet</Name>
542-
</ProjectReference>
543-
</ItemGroup>
544539
<ItemGroup>
545540
<EmbeddedResource Include="Properties\Resources.resx">
546541
<Generator>ResXFileCodeGenerator</Generator>
@@ -576,6 +571,12 @@
576571
<EmbeddedResource Include="Data\Key.SSH2.DSA.Encrypted.Des.CBC.12345.txt" />
577572
<EmbeddedResource Include="Data\Key.SSH2.DSA.txt" />
578573
</ItemGroup>
574+
<ItemGroup>
575+
<ProjectReference Include="..\Renci.SshNet\Renci.SshNet.csproj">
576+
<Project>{2f5f8c90-0bd1-424f-997c-7bc6280919d1}</Project>
577+
<Name>Renci.SshNet</Name>
578+
</ProjectReference>
579+
</ItemGroup>
579580
<Import Project="$(MSBuildBinPath)\Microsoft.CSharp.targets" />
580581
<!-- To modify your build process, add your task inside one of the targets below and uncomment it.
581582
Other similar extension points exist, see Microsoft.Common.targets.

src/Renci.SshNet/Session.NET.cs

Lines changed: 16 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,6 @@ namespace Renci.SshNet
66
{
77
public partial class Session
88
{
9-
#if FEATURE_SOCKET_POLL
10-
/// <summary>
11-
/// Holds the lock object to ensure read access to the socket is synchronized.
12-
/// </summary>
13-
private readonly object _socketReadLock = new object();
14-
#endif // FEATURE_SOCKET_POLL
15-
169
#if FEATURE_SOCKET_POLL
1710
/// <summary>
1811
/// Gets a value indicating whether the socket is connected.
@@ -63,35 +56,28 @@ public partial class Session
6356
#endif
6457
partial void IsSocketConnected(ref bool isConnected)
6558
{
66-
isConnected = (_socket != null && _socket.Connected);
67-
#if FEATURE_SOCKET_POLL
68-
if (isConnected)
59+
DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checking socket", ToHex(SessionId), DateTime.Now.Ticks));
60+
61+
lock (_socketDisposeLock)
6962
{
70-
// synchronize this to ensure thread B does not reset the wait handle before
71-
// thread A was able to check whether "bytes read from socket" signal was
72-
// actually received
73-
lock (_socketReadLock)
63+
#if FEATURE_SOCKET_POLL
64+
if (_socket == null || !_socket.Connected)
7465
{
75-
DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checking socket", ToHex(SessionId), DateTime.Now.Ticks));
66+
isConnected = false;
67+
return;
68+
}
7669

77-
// reset waithandle, as we're only interested in reads that take
78-
// place between Poll and the Available check
79-
_bytesReadFromSocket.Reset();
80-
var connectionClosedOrDataAvailable = _socket.Poll(100, SelectMode.SelectRead);
70+
lock (_socketReadLock)
71+
{
72+
var connectionClosedOrDataAvailable = _socket.Poll(1, SelectMode.SelectRead);
8173
isConnected = !(connectionClosedOrDataAvailable && _socket.Available == 0);
82-
if (!isConnected)
83-
{
84-
// the race condition is between the Socket.Poll call and
85-
// Socket.Available, but the event handler - where we signal that
86-
// bytes have been received from the socket - is sometimes invoked
87-
// shortly after
88-
isConnected = _bytesReadFromSocket.WaitOne(500);
89-
}
90-
91-
DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checked socket", ToHex(SessionId), DateTime.Now.Ticks));
9274
}
93-
}
75+
#else
76+
isConnected = _socket != null && _socket.Connected;
9477
#endif // FEATURE_SOCKET_POLL
78+
}
79+
80+
DiagnosticAbstraction.Log(string.Format("[{0}] {1} Checked socket", ToHex(SessionId), DateTime.Now.Ticks));
9581
}
9682
}
9783
}

0 commit comments

Comments
 (0)