Skip to content

Commit 6129aec

Browse files
SNOW-2299238: Renew idle session in pool if keepalive enabled (#1256)
Co-authored-by: Marcin Gemra <[email protected]>
1 parent 6ab6f88 commit 6129aec

File tree

5 files changed

+123
-10
lines changed

5 files changed

+123
-10
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,6 @@
99
- Upgraded AWS SDK library to v4.
1010
- Added the `changelog.yml` GitHub workflow to ensure changelog is updated on release PRs.
1111
- Removed internal classes from public API.
12-
- Added support for explicitly setting Azure managed identity client ID via `MANAGED_IDENTITY_CLIENT_ID` environment variable.
12+
- Added support for explicitly setting Azure managed identity client ID via `MANAGED_IDENTITY_CLIENT_ID` environmen
13+
- v5.0.1
14+
- Renew idle sessions in the pool if keep alive is enabled.

Snowflake.Data.Tests/Mock/MockRestSessionExpired.cs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ class MockRestSessionExpired : IMockRestRequester
1212
{
1313
static private readonly String EXPIRED_SESSION_TOKEN = "session_expired_token";
1414

15-
static private readonly String NEW_SESSION_TOKEN = "new_session_token";
15+
static internal readonly String NEW_SESSION_TOKEN = "new_session_token";
1616

1717
static private readonly String TOKEN_FMT = "Snowflake Token=\"{0}\"";
1818

19+
static internal readonly String THROW_ERROR_TOKEN = "throw_error_token";
20+
1921
static internal readonly int SESSION_EXPIRED_CODE = 390112;
2022

2123
public string FirstTimeRequestID;
@@ -88,8 +90,12 @@ public Task<T> PostAsync<T>(IRestRequest request, CancellationToken cancellation
8890
return Task.FromResult<T>((T)(object)queryExecResponse);
8991
}
9092
}
91-
else if (sfRequest.jsonBody is RenewSessionRequest)
93+
else if (sfRequest.jsonBody is RenewSessionRequest renewSessionRequest)
9294
{
95+
if (renewSessionRequest.oldSessionToken == THROW_ERROR_TOKEN)
96+
{
97+
throw new Exception("Error while renewing session");
98+
}
9399
return Task.FromResult<T>((T)(object)new RenewSessionResponse
94100
{
95101
success = true,
@@ -100,6 +106,20 @@ public Task<T> PostAsync<T>(IRestRequest request, CancellationToken cancellation
100106
}
101107
});
102108
}
109+
else if (sfRequest.jsonBody == null)
110+
{
111+
if (typeof(T) == typeof(CloseResponse))
112+
{
113+
return Task.FromResult<T>((T)(object)new CloseResponse
114+
{
115+
success = true
116+
});
117+
}
118+
return Task.FromResult<T>((T)(object)new NullDataResponse
119+
{
120+
success = true
121+
});
122+
}
103123
else
104124
{
105125
return Task.FromResult<T>((T)(object)null);

Snowflake.Data.Tests/UnitTests/Session/SessionPoolTest.cs

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using Snowflake.Data.Core;
77
using Snowflake.Data.Core.Session;
88
using Snowflake.Data.Core.Tools;
9+
using Snowflake.Data.Tests.Mock;
910
using Snowflake.Data.Tests.Util;
1011

1112
namespace Snowflake.Data.Tests.UnitTests.Session
@@ -320,9 +321,72 @@ public void TestShouldClearQueryContextCacheOnReturningToConnectionPool()
320321
Assert.AreEqual(0, session.GetQueryContextRequest().Entries.Count);
321322
}
322323

323-
private SFSession CreateSessionWithCurrentStartTime(string connectionString)
324+
[Test]
325+
public void TestShouldRenewSessionIfKeepAliveIsEnabled()
326+
{
327+
// arrange
328+
var connectionString = "account=testAccount;user=testUser;password=testPassword;";
329+
var session = CreateSessionWithCurrentStartTime(connectionString, new MockRestSessionExpired());
330+
session.startHeartBeatForThisSession();
331+
var pool = SessionPool.CreateSessionPool(connectionString, null, null, null);
332+
pool.SetTimeout(0);
333+
pool._idleSessions.Add(session);
334+
335+
// act
336+
pool.ExtractIdleSession(connectionString);
337+
338+
// assert
339+
Assert.AreEqual(MockRestSessionExpired.NEW_SESSION_TOKEN, session.sessionToken);
340+
}
341+
342+
[Test]
343+
public void TestShouldContinueExecutionIfRenewingFails()
344+
{
345+
// arrange
346+
var connectionString = "account=testAccount;user=testUser;password=testPassword;";
347+
var session = CreateSessionWithCurrentStartTime(connectionString, new MockRestSessionExpired());
348+
session.startHeartBeatForThisSession();
349+
var pool = SessionPool.CreateSessionPool(connectionString, null, null, null);
350+
pool.SetTimeout(0);
351+
pool._idleSessions.Add(session);
352+
session.sessionToken = MockRestSessionExpired.THROW_ERROR_TOKEN;
353+
354+
// act
355+
try
356+
{
357+
pool.ExtractIdleSession(connectionString);
358+
}
359+
catch
360+
{
361+
Assert.Fail("Should not throw exception even if session renewal fails");
362+
}
363+
364+
// assert
365+
Assert.AreNotEqual(MockRestSessionExpired.NEW_SESSION_TOKEN, session.sessionToken);
366+
}
367+
368+
[Test]
369+
public void TestShouldNotRenewSessionIfKeepAliveIsDisabled()
370+
{
371+
// arrange
372+
var connectionString = "account=testAccount;user=testUser;password=testPassword;";
373+
var session = CreateSessionWithCurrentStartTime(connectionString, new MockRestSessionExpired());
374+
session.stopHeartBeatForThisSession();
375+
var pool = SessionPool.CreateSessionPool(connectionString, null, null, null);
376+
pool.SetTimeout(0);
377+
pool._idleSessions.Add(session);
378+
379+
// act
380+
pool.ExtractIdleSession(connectionString);
381+
382+
// assert
383+
Assert.IsNull(session.sessionToken);
384+
}
385+
386+
private SFSession CreateSessionWithCurrentStartTime(string connectionString, IMockRestRequester restRequester = null)
324387
{
325-
var session = new SFSession(connectionString, new SessionPropertiesContext());
388+
var session = restRequester == null ? new SFSession(connectionString, new SessionPropertiesContext()) :
389+
new SFSession(connectionString, new SessionPropertiesContext(), restRequester);
326390
var now = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
327391
session.SetStartTime(now);
328392
return session;

Snowflake.Data/Core/Session/SFSession.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ internal class SFSession
6565
private readonly EasyLoggingStarter _easyLoggingStarter = EasyLoggingStarter.Instance;
6666

6767
private long _startTime = 0;
68+
69+
private long _timeSinceLastRenew = 0;
70+
6871
internal string ConnectionString { get; }
6972

7073
internal SessionPropertiesContext PropertiesContext { get; }
@@ -129,6 +132,7 @@ internal void ProcessLoginResponse(LoginResponse authnResponse)
129132
}
130133
logger.Debug($"Session opened: {sessionId}");
131134
_startTime = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
135+
_timeSinceLastRenew = _startTime;
132136
}
133137
else
134138
{
@@ -404,6 +408,7 @@ internal void renewSession()
404408
}
405409
else
406410
{
411+
_timeSinceLastRenew = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
407412
sessionToken = response.data.sessionToken;
408413
masterToken = response.data.masterToken;
409414
}
@@ -702,14 +707,15 @@ internal virtual bool IsNotOpen()
702707
internal virtual bool IsExpired(TimeSpan timeout, long utcTimeInMillis)
703708
{
704709
var hasEverBeenOpened = !IsNotOpen();
705-
return hasEverBeenOpened && TimeoutHelper.IsExpired(_startTime, utcTimeInMillis, timeout);
710+
return hasEverBeenOpened && TimeoutHelper.IsExpired(_timeSinceLastRenew, utcTimeInMillis, timeout);
706711
}
707712

708713
internal long GetStartTime() => _startTime;
709714

710715
internal void SetStartTime(long startTime)
711716
{
712717
_startTime = startTime;
718+
_timeSinceLastRenew = _startTime;
713719
}
714720

715721
internal void ReplaceAuthenticator(IAuthenticator authenticator)

Snowflake.Data/Core/Session/SessionPool.cs

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ sealed class SessionPool : IDisposable
1919
private static ISessionFactory s_sessionFactory = new SessionFactory();
2020

2121
private readonly Guid _id = Guid.NewGuid();
22-
private readonly List<SFSession> _idleSessions;
22+
internal readonly List<SFSession> _idleSessions;
2323
private readonly IWaitingQueue _waitingForIdleSessionQueue;
2424
private readonly ISessionCreationTokenCounter _sessionCreationTokenCounter;
2525
private readonly ISessionCreationTokenCounter _noPoolingSessionCreationTokenCounter = new NonCountingSessionCreationTokenCounter();
@@ -358,7 +358,7 @@ private SFSession WaitForSession(string connStr)
358358

359359
private static Exception WaitingFailedException() => new Exception("Could not obtain a connection from the pool within a given timeout");
360360

361-
private SFSession ExtractIdleSession(string connStr)
361+
internal SFSession ExtractIdleSession(string connStr)
362362
{
363363
for (int i = 0; i < _idleSessions.Count; i++)
364364
{
@@ -369,8 +369,29 @@ private SFSession ExtractIdleSession(string connStr)
369369
var timeNow = DateTimeOffset.UtcNow.ToUnixTimeMilliseconds();
370370
if (session.IsExpired(_poolConfig.ExpirationTimeout, timeNow))
371371
{
372-
Task.Run(() => session.close());
373-
i--;
372+
s_logger.Debug($"session {session.sessionId}{PoolIdentification()} has expired");
373+
if (session.isHeartBeatEnabled)
374+
{
375+
s_logger.Debug($"keep alive enabled: renewing session {session.sessionId}{PoolIdentification()}");
376+
try
377+
{
378+
session.renewSession();
379+
_busySessionsCounter.Increase();
380+
return session;
381+
}
382+
catch (Exception e)
383+
{
384+
s_logger.Error($"failed to renew session {session.sessionId}{PoolIdentification()} due to error: {e.Message}");
385+
Task.Run(() => session.close());
386+
i--;
387+
}
388+
}
389+
else
390+
{
391+
s_logger.Debug($"keep alive disabled: closing the expired session {session.sessionId}{PoolIdentification()}");
392+
Task.Run(() => session.close());
393+
i--;
394+
}
374395
}
375396
else
376397
{

0 commit comments

Comments
 (0)