Skip to content

Commit 28e6742

Browse files
authored
fix ConnectAsync not respecting the connection timeout (#1502)
1 parent 4c5d0c0 commit 28e6742

File tree

2 files changed

+84
-1
lines changed

2 files changed

+84
-1
lines changed

src/Renci.SshNet/BaseClient.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,17 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
307307
DisposeSession(session);
308308
}
309309

310-
Session = await CreateAndConnectSessionAsync(cancellationToken).ConfigureAwait(false);
310+
using var timeoutCancellationTokenSource = new CancellationTokenSource(ConnectionInfo.Timeout);
311+
using var linkedCancellationTokenSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCancellationTokenSource.Token);
312+
313+
try
314+
{
315+
Session = await CreateAndConnectSessionAsync(linkedCancellationTokenSource.Token).ConfigureAwait(false);
316+
}
317+
catch (OperationCanceledException ex) when (timeoutCancellationTokenSource.IsCancellationRequested)
318+
{
319+
throw new SshOperationTimeoutException("Connection has timed out.", ex);
320+
}
311321
}
312322

313323
try
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
using System;
2+
using System.Threading;
3+
using System.Threading.Tasks;
4+
5+
using Microsoft.VisualStudio.TestTools.UnitTesting;
6+
7+
using Moq;
8+
9+
#if !NET8_0_OR_GREATER
10+
using Renci.SshNet.Abstractions;
11+
#endif
12+
using Renci.SshNet.Common;
13+
using Renci.SshNet.Connection;
14+
15+
namespace Renci.SshNet.Tests.Classes
16+
{
17+
[TestClass]
18+
public class BaseClientTest_ConnectAsync_Timeout
19+
{
20+
private BaseClient _client;
21+
22+
[TestInitialize]
23+
public void Init()
24+
{
25+
var sessionMock = new Mock<ISession>();
26+
var serviceFactoryMock = new Mock<IServiceFactory>();
27+
var socketFactoryMock = new Mock<ISocketFactory>();
28+
29+
sessionMock.Setup(p => p.ConnectAsync(It.IsAny<CancellationToken>()))
30+
.Returns<CancellationToken>(c => Task.Delay(Timeout.Infinite, c));
31+
32+
serviceFactoryMock.Setup(p => p.CreateSocketFactory())
33+
.Returns(socketFactoryMock.Object);
34+
35+
var connectionInfo = new ConnectionInfo("host", "user", new PasswordAuthenticationMethod("user", "pwd"))
36+
{
37+
Timeout = TimeSpan.FromSeconds(1)
38+
};
39+
40+
serviceFactoryMock.Setup(p => p.CreateSession(connectionInfo, socketFactoryMock.Object))
41+
.Returns(sessionMock.Object);
42+
43+
_client = new MyClient(connectionInfo, false, serviceFactoryMock.Object);
44+
}
45+
46+
[TestMethod]
47+
public async Task ConnectAsyncWithTimeoutThrowsSshTimeoutException()
48+
{
49+
await Assert.ThrowsExceptionAsync<SshOperationTimeoutException>(() => _client.ConnectAsync(CancellationToken.None));
50+
}
51+
52+
[TestMethod]
53+
public async Task ConnectAsyncWithCancelledTokenThrowsOperationCancelledException()
54+
{
55+
using var cancellationTokenSource = new CancellationTokenSource();
56+
await cancellationTokenSource.CancelAsync();
57+
await Assert.ThrowsExceptionAsync<OperationCanceledException>(() => _client.ConnectAsync(cancellationTokenSource.Token));
58+
}
59+
60+
[TestCleanup]
61+
public void Cleanup()
62+
{
63+
_client?.Dispose();
64+
}
65+
66+
private class MyClient : BaseClient
67+
{
68+
public MyClient(ConnectionInfo connectionInfo, bool ownsConnectionInfo, IServiceFactory serviceFactory) : base(connectionInfo, ownsConnectionInfo, serviceFactory)
69+
{
70+
}
71+
}
72+
}
73+
}

0 commit comments

Comments
 (0)