Skip to content

Commit 9acd180

Browse files
SNOW-2038030: Add check for expired session when getting result with ID (#1182)
1 parent 85e437d commit 9acd180

File tree

3 files changed

+110
-3
lines changed

3 files changed

+110
-3
lines changed

Snowflake.Data.Tests/Mock/MockRestSessionExpired.cs

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,11 @@ class MockRestSessionExpired : IMockRestRequester
1212
{
1313
static private readonly String EXPIRED_SESSION_TOKEN = "session_expired_token";
1414

15+
static private readonly String NEW_SESSION_TOKEN = "new_session_token";
16+
1517
static private readonly String TOKEN_FMT = "Snowflake Token=\"{0}\"";
1618

17-
static private readonly int SESSION_EXPIRED_CODE = 390112;
19+
static internal readonly int SESSION_EXPIRED_CODE = 390112;
1820

1921
public string FirstTimeRequestID;
2022

@@ -54,7 +56,7 @@ public Task<T> PostAsync<T>(IRestRequest request, CancellationToken cancellation
5456
};
5557
return Task.FromResult<T>((T)(object)queryExecResponse);
5658
}
57-
else if (sfRequest.authorizationToken.Equals(String.Format(TOKEN_FMT, "new_session_token")))
59+
else if (sfRequest.authorizationToken.Equals(String.Format(TOKEN_FMT, NEW_SESSION_TOKEN)))
5860
{
5961
SecondTimeRequestID = ExtractRequestID(sfRequest.Url.Query);
6062
QueryExecResponse queryExecResponse = new QueryExecResponse
@@ -93,7 +95,7 @@ public Task<T> PostAsync<T>(IRestRequest request, CancellationToken cancellation
9395
success = true,
9496
data = new RenewSessionResponseData()
9597
{
96-
sessionToken = "new_session_token",
98+
sessionToken = NEW_SESSION_TOKEN,
9799
masterToken = "new_master_token"
98100
}
99101
});
@@ -116,6 +118,46 @@ public T Get<T>(IRestRequest request)
116118

117119
public Task<T> GetAsync<T>(IRestRequest request, CancellationToken cancellationToken)
118120
{
121+
SFRestRequest sfRequest = (SFRestRequest)request;
122+
if (sfRequest.Url.ToString().Contains("retryId"))
123+
{
124+
QueryExecResponse queryExecResponse = new QueryExecResponse
125+
{
126+
success = false,
127+
code = SESSION_EXPIRED_CODE
128+
};
129+
return Task.FromResult<T>((T)(object)queryExecResponse);
130+
}
131+
if (sfRequest.authorizationToken.Equals(String.Format(TOKEN_FMT, EXPIRED_SESSION_TOKEN)))
132+
{
133+
QueryExecResponse queryExecResponse = new QueryExecResponse
134+
{
135+
success = false,
136+
code = SESSION_EXPIRED_CODE
137+
};
138+
return Task.FromResult<T>((T)(object)queryExecResponse);
139+
}
140+
if (sfRequest.authorizationToken.Equals(String.Format(TOKEN_FMT, NEW_SESSION_TOKEN)))
141+
{
142+
QueryExecResponse queryExecResponse = new QueryExecResponse
143+
{
144+
success = true,
145+
data = new QueryExecResponseData
146+
{
147+
rowSet = new string[,] { { "abc" } },
148+
rowType = new List<ExecResponseRowType>()
149+
{
150+
new ExecResponseRowType
151+
{
152+
name = "colOne",
153+
type = SFDataType.TEXT.ToString()
154+
}
155+
},
156+
parameters = new List<NameValueParameter>()
157+
}
158+
};
159+
return Task.FromResult<T>((T)(object)queryExecResponse);
160+
}
119161
return Task.FromResult<T>((T)(object)null);
120162
}
121163

Snowflake.Data.Tests/UnitTests/SFStatementTest.cs

100755100644
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Snowflake.Data.Core;
44
using NUnit.Framework;
55
using System;
6+
using System.Threading.Tasks;
67
using Snowflake.Data.Core.Session;
78

89
namespace Snowflake.Data.Tests.UnitTests
@@ -29,6 +30,56 @@ public void TestSessionRenew()
2930
Assert.AreEqual(restRequester.FirstTimeRequestID, restRequester.SecondTimeRequestID);
3031
}
3132

33+
[Test]
34+
public void TestSessionRenewGetResultWithId()
35+
{
36+
Mock.MockRestSessionExpired restRequester = new Mock.MockRestSessionExpired();
37+
SFSession sfSession = new SFSession("account=test;user=test;password=test", new SessionPropertiesContext(), restRequester);
38+
sfSession.Open();
39+
SFStatement statement = new SFStatement(sfSession);
40+
SFBaseResultSet resultSet = statement.GetResultWithId("mockId");
41+
Assert.AreEqual(true, resultSet.Next());
42+
Assert.AreEqual("abc", resultSet.GetString(0));
43+
Assert.AreEqual("new_session_token", sfSession.sessionToken);
44+
Assert.AreEqual("new_master_token", sfSession.masterToken);
45+
}
46+
47+
[Test]
48+
public void TestSessionRenewGetResultWithIdOnlyRetries3Times()
49+
{
50+
Mock.MockRestSessionExpired restRequester = new Mock.MockRestSessionExpired();
51+
SFSession sfSession = new SFSession("account=test;user=test;password=test", new SessionPropertiesContext(), restRequester);
52+
sfSession.Open();
53+
SFStatement statement = new SFStatement(sfSession);
54+
var thrown = Assert.Throws<SnowflakeDbException>(() => statement.GetResultWithId("retryId"));
55+
Assert.AreEqual(thrown.ErrorCode, Mock.MockRestSessionExpired.SESSION_EXPIRED_CODE);
56+
}
57+
58+
[Test]
59+
public async Task TestSessionRenewGetResultWithIdAsync()
60+
{
61+
Mock.MockRestSessionExpired restRequester = new Mock.MockRestSessionExpired();
62+
SFSession sfSession = new SFSession("account=test;user=test;password=test", new SessionPropertiesContext(), restRequester);
63+
await sfSession.OpenAsync(CancellationToken.None);
64+
SFStatement statement = new SFStatement(sfSession);
65+
SFBaseResultSet resultSet = await statement.GetResultWithIdAsync("mockId", CancellationToken.None);
66+
Assert.AreEqual(true, resultSet.Next());
67+
Assert.AreEqual("abc", resultSet.GetString(0));
68+
Assert.AreEqual("new_session_token", sfSession.sessionToken);
69+
Assert.AreEqual("new_master_token", sfSession.masterToken);
70+
}
71+
72+
[Test]
73+
public async Task TestSessionRenewGetResultWithIdOnlyRetries3TimesAsync()
74+
{
75+
Mock.MockRestSessionExpired restRequester = new Mock.MockRestSessionExpired();
76+
SFSession sfSession = new SFSession("account=test;user=test;password=test", new SessionPropertiesContext(), restRequester);
77+
await sfSession.OpenAsync(CancellationToken.None);
78+
SFStatement statement = new SFStatement(sfSession);
79+
var thrown = Assert.ThrowsAsync<SnowflakeDbException>(async () => await statement.GetResultWithIdAsync("retryId", CancellationToken.None));
80+
Assert.AreEqual(thrown.ErrorCode, Mock.MockRestSessionExpired.SESSION_EXPIRED_CODE);
81+
}
82+
3283
// Mock test for session renew during query execution
3384
[Test]
3485
public void TestSessionRenewDuringQueryExec()

Snowflake.Data/Core/SFStatement.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ class SFStatement
113113

114114
private const int SF_QUERY_IN_PROGRESS_ASYNC = 333334;
115115

116+
private const int GetResultWithIdMaxRetriesCount = 3;
117+
116118
private string _requestId;
117119

118120
private readonly object _requestIdLock = new object();
@@ -573,6 +575,12 @@ internal async Task<SFBaseResultSet> GetResultWithIdAsync(string resultId, Cance
573575
var req = BuildResultRequestWithId(resultId);
574576
QueryExecResponse response = null;
575577
response = await _restRequester.GetAsync<QueryExecResponse>(req, cancellationToken).ConfigureAwait(false);
578+
for (var retryCount = 0; retryCount < GetResultWithIdMaxRetriesCount && SessionExpired(response); retryCount++)
579+
{
580+
await SfSession.renewSessionAsync(cancellationToken).ConfigureAwait(false);
581+
req.authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken);
582+
response = await _restRequester.GetAsync<QueryExecResponse>(req, cancellationToken).ConfigureAwait(false);
583+
}
576584
return BuildResultSet(response, cancellationToken);
577585
}
578586

@@ -581,6 +589,12 @@ internal SFBaseResultSet GetResultWithId(string resultId)
581589
var req = BuildResultRequestWithId(resultId);
582590
QueryExecResponse response = null;
583591
response = _restRequester.Get<QueryExecResponse>(req);
592+
for (var retryCount = 0; retryCount < GetResultWithIdMaxRetriesCount && SessionExpired(response); retryCount++)
593+
{
594+
SfSession.renewSession();
595+
req.authorizationToken = string.Format(SF_AUTHORIZATION_SNOWFLAKE_FMT, SfSession.sessionToken);
596+
response = _restRequester.Get<QueryExecResponse>(req);
597+
}
584598
return BuildResultSet(response, CancellationToken.None);
585599
}
586600

0 commit comments

Comments
 (0)