Skip to content

Commit 4906b64

Browse files
committed
Immediately sent identification string when connection is established.
Fixes #689.
1 parent ab827c2 commit 4906b64

File tree

2 files changed

+71
-18
lines changed

2 files changed

+71
-18
lines changed

src/Renci.SshNet.Tests/Classes/SessionTest.cs

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
using System.Net;
44
using System.Net.Sockets;
55
using System.Text;
6+
using System.Threading;
67
using Microsoft.VisualStudio.TestTools.UnitTesting;
78
using Moq;
89
using Renci.SshNet.Common;
10+
using Renci.SshNet.Messages.Transport;
911
using Renci.SshNet.Tests.Common;
1012
using Renci.SshNet.Tests.Properties;
1113

@@ -98,6 +100,53 @@ public void ConnectShouldSkipLinesBeforeProtocolIdentificationString()
98100
}
99101
}
100102

103+
[TestMethod]
104+
public void ConnectShouldImmediatelySendIdentificationStringWhenConnectionHasBeenEstablised()
105+
{
106+
var serverEndPoint = new IPEndPoint(IPAddress.Loopback, 8122);
107+
var connectionInfo = CreateConnectionInfo(serverEndPoint, TimeSpan.FromSeconds(5));
108+
109+
using (var serverStub = new AsyncSocketListener(serverEndPoint))
110+
{
111+
serverStub.Connected += socket =>
112+
{
113+
var identificationBytes = new byte[2048];
114+
var bytesReceived = socket.Receive(identificationBytes);
115+
116+
if (bytesReceived > 0)
117+
{
118+
var identificationSttring = Encoding.ASCII.GetString(identificationBytes, 0, bytesReceived);
119+
Console.WriteLine("STRING=" + identificationSttring);
120+
Console.WriteLine("DONE");
121+
122+
socket.Send(Encoding.ASCII.GetBytes("\r\n"));
123+
socket.Send(Encoding.ASCII.GetBytes("WELCOME banner\r\n"));
124+
socket.Send(Encoding.ASCII.GetBytes("SSH-666-SshStub\r\n"));
125+
}
126+
127+
socket.Shutdown(SocketShutdown.Send);
128+
};
129+
serverStub.Start();
130+
131+
using (var session = new Session(connectionInfo, _serviceFactoryMock.Object))
132+
{
133+
try
134+
{
135+
session.Connect();
136+
Assert.Fail();
137+
}
138+
catch (SshConnectionException ex)
139+
{
140+
Assert.IsNull(ex.InnerException);
141+
Assert.AreEqual("Server version '666' is not supported.", ex.Message);
142+
143+
Assert.AreEqual("SSH-666-SshStub", connectionInfo.ServerVersion);
144+
}
145+
}
146+
}
147+
}
148+
149+
101150
[TestMethod]
102151
public void ConnectShouldSupportProtocolIdentificationStringThatDoesNotEndWithCrlf()
103152
{

src/Renci.SshNet/Session.cs

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -577,14 +577,14 @@ public void Connect()
577577

578578
lock (this)
579579
{
580-
// If connected don't connect again
580+
// If connected don't connect again
581581
if (IsConnected)
582582
return;
583583

584-
// reset connection specific information
584+
// Reset connection specific information
585585
Reset();
586586

587-
// Build list of available messages while connecting
587+
// Build list of available messages while connecting
588588
_sshMessageFactory = new SshMessageFactory();
589589

590590
switch (ConnectionInfo.ProxyType)
@@ -606,10 +606,14 @@ public void Connect()
606606
break;
607607
}
608608

609+
// Immediately send the identification string since the spec states both sides MUST send an identification string
610+
// when the connection has been established
611+
SocketAbstraction.Send(_socket, Encoding.UTF8.GetBytes(string.Format(CultureInfo.InvariantCulture, "{0}\x0D\x0A", ClientVersion)));
612+
609613
Match versionMatch;
610614

611-
// Get server version from the server,
612-
// ignore text lines which are sent before if any
615+
// Get server version from the server,
616+
// ignore text lines which are sent before if any
613617
while (true)
614618
{
615619
var serverVersion = SocketReadLine(_socket, ConnectionInfo.Timeout);
@@ -623,11 +627,11 @@ public void Connect()
623627
}
624628
}
625629

626-
// Set connection versions
630+
// Set connection versions
627631
ConnectionInfo.ServerVersion = ServerVersion;
628632
ConnectionInfo.ClientVersion = ClientVersion;
629633

630-
// Get server SSH version
634+
// Get server SSH version
631635
var version = versionMatch.Result("${protoversion}");
632636

633637
var softwareName = versionMatch.Result("${softwareversion}");
@@ -639,9 +643,7 @@ public void Connect()
639643
throw new SshConnectionException(string.Format(CultureInfo.CurrentCulture, "Server version '{0}' is not supported.", version), DisconnectReason.ProtocolVersionNotSupported);
640644
}
641645

642-
SocketAbstraction.Send(_socket, Encoding.UTF8.GetBytes(string.Format(CultureInfo.InvariantCulture, "{0}\x0D\x0A", ClientVersion)));
643-
644-
// Register Transport response messages
646+
// Register Transport response messages
645647
RegisterMessage("SSH_MSG_DISCONNECT");
646648
RegisterMessage("SSH_MSG_IGNORE");
647649
RegisterMessage("SSH_MSG_UNIMPLEMENTED");
@@ -650,29 +652,29 @@ public void Connect()
650652
RegisterMessage("SSH_MSG_KEXINIT");
651653
RegisterMessage("SSH_MSG_NEWKEYS");
652654

653-
// Some server implementations might sent this message first, prior establishing encryption algorithm
655+
// Some server implementations might sent this message first, prior to establishing encryption algorithm
654656
RegisterMessage("SSH_MSG_USERAUTH_BANNER");
655657

656-
// mark the message listener threads as started
658+
// Mark the message listener threads as started
657659
_messageListenerCompleted.Reset();
658660

659-
// Start incoming request listener
661+
// Start incoming request listener
660662
ThreadAbstraction.ExecuteThread(() => MessageListener());
661663

662-
// Wait for key exchange to be completed
664+
// Wait for key exchange to be completed
663665
WaitOnHandle(_keyExchangeCompletedWaitHandle);
664666

665-
// If sessionId is not set then its not connected
667+
// If sessionId is not set then its not connected
666668
if (SessionId == null)
667669
{
668670
Disconnect();
669671
return;
670672
}
671673

672-
// Request user authorization service
674+
// Request user authorization service
673675
SendMessage(new ServiceRequestMessage(ServiceName.UserAuthentication));
674676

675-
// Wait for service to be accepted
677+
// Wait for service to be accepted
676678
WaitOnHandle(_serviceAccepted);
677679

678680
if (string.IsNullOrEmpty(ConnectionInfo.Username))
@@ -687,7 +689,7 @@ public void Connect()
687689
ConnectionInfo.Authenticate(this, _serviceFactory);
688690
_isAuthenticated = true;
689691

690-
// Register Connection messages
692+
// Register Connection messages
691693
RegisterMessage("SSH_MSG_REQUEST_SUCCESS");
692694
RegisterMessage("SSH_MSG_REQUEST_FAILURE");
693695
RegisterMessage("SSH_MSG_CHANNEL_OPEN_CONFIRMATION");
@@ -2003,6 +2005,8 @@ private void MessageListener()
20032005
break;
20042006
}
20052007

2008+
Console.WriteLine("RECEIVED MESSAGe " + message.GetType());
2009+
20062010
// process message
20072011
message.Process(this);
20082012
}

0 commit comments

Comments
 (0)