Skip to content

Commit cc1cbd0

Browse files
Merge branch 'master' into NO-SNOW-jenkins-trigger
2 parents 9c713a7 + b0fd5cb commit cc1cbd0

File tree

6 files changed

+167
-15
lines changed

6 files changed

+167
-15
lines changed

Snowflake.Data.Tests/IntegrationTests/SFMultiStatementsIT.cs

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public void TestSelectWithoutBinding()
7979
conn.Close();
8080
}
8181
}
82-
82+
8383
[Test]
8484
public async Task TestSelectAsync()
8585
{
@@ -120,7 +120,7 @@ public async Task TestSelectAsync()
120120
conn.Close();
121121
}
122122
}
123-
123+
124124
[Test]
125125
public void TestSelectWithBinding()
126126
{
@@ -236,17 +236,17 @@ public void TestMixedQueryTypeWithBinding()
236236
DbDataReader reader = cmd.ExecuteReader();
237237

238238
// result of create
239-
Assert.IsFalse(reader.HasRows);
239+
Assert.IsTrue(reader.HasRows);
240240
Assert.AreEqual(0, reader.RecordsAffected);
241241

242242
// result of insert #1
243243
Assert.IsTrue(reader.NextResult());
244-
Assert.IsFalse(reader.HasRows);
244+
Assert.IsTrue(reader.HasRows);
245245
Assert.AreEqual(1, reader.RecordsAffected);
246246

247247
// result of insert #2
248248
Assert.IsTrue(reader.NextResult());
249-
Assert.IsFalse(reader.HasRows);
249+
Assert.IsTrue(reader.HasRows);
250250
Assert.AreEqual(2, reader.RecordsAffected);
251251

252252
// result of select
@@ -266,7 +266,7 @@ public void TestMixedQueryTypeWithBinding()
266266

267267
// result of drop
268268
Assert.IsTrue(reader.NextResult());
269-
Assert.IsFalse(reader.HasRows);
269+
Assert.IsTrue(reader.HasRows);
270270
Assert.AreEqual(0, reader.RecordsAffected);
271271

272272
Assert.IsFalse(reader.NextResult());
@@ -382,7 +382,7 @@ public void TestWithAllQueryTypes()
382382

383383
// result of create
384384
Assert.IsTrue(reader.NextResult());
385-
Assert.IsFalse(reader.HasRows);
385+
Assert.IsTrue(reader.HasRows);
386386
Assert.AreEqual(0, reader.RecordsAffected);
387387

388388
// result of explain
@@ -400,7 +400,7 @@ public void TestWithAllQueryTypes()
400400

401401
// result of insert
402402
Assert.IsTrue(reader.NextResult());
403-
Assert.IsFalse(reader.HasRows);
403+
Assert.IsTrue(reader.HasRows);
404404
Assert.AreEqual(1, reader.RecordsAffected);
405405

406406
// result of describe
@@ -420,7 +420,7 @@ public void TestWithAllQueryTypes()
420420

421421
// result of create
422422
Assert.IsTrue(reader.NextResult());
423-
Assert.IsFalse(reader.HasRows);
423+
Assert.IsTrue(reader.HasRows);
424424
Assert.AreEqual(0, reader.RecordsAffected);
425425

426426
// result of call
@@ -434,7 +434,7 @@ public void TestWithAllQueryTypes()
434434

435435
// result of use
436436
Assert.IsTrue(reader.NextResult());
437-
Assert.IsFalse(reader.HasRows);
437+
Assert.IsTrue(reader.HasRows);
438438
Assert.AreEqual(0, reader.RecordsAffected);
439439

440440
Assert.IsFalse(reader.NextResult());
@@ -531,5 +531,63 @@ public void TestWithMultipleStatementSetting()
531531
conn.Close();
532532
}
533533
}
534+
535+
[Test]
536+
public void TestResultSetReturnedForAllQueryTypes()
537+
{
538+
using (DbConnection conn = new SnowflakeDbConnection())
539+
{
540+
conn.ConnectionString = ConnectionString;
541+
conn.Open();
542+
543+
using (DbCommand cmd = conn.CreateCommand())
544+
{
545+
cmd.CommandText = "set query_tag = (select 'dummy_tag');" +
546+
"alter session set query_tag='dummy_tag';" +
547+
"select 1;" +
548+
$"create or replace temporary table {TableName}(c1 varchar);" +
549+
$"explain using text select * from {TableName};" +
550+
"show parameters;" +
551+
$"insert into {TableName} values ('str1');" +
552+
$"update {TableName} set c1 = 'str2';" +
553+
$"select * from {TableName};" +
554+
$"desc table {TableName};" +
555+
$"copy into @%{TableName} from {TableName};" +
556+
$"list @%{TableName};" +
557+
$"remove @%{TableName};" +
558+
"create or replace temporary procedure P1() returns varchar language javascript as $$ return ''; $$;" +
559+
"call p1();" +
560+
$"use role {testConfig.role}";
561+
562+
var stmtCount = 16;
563+
564+
// Set statement count
565+
var stmtCountParam = cmd.CreateParameter();
566+
stmtCountParam.ParameterName = "MULTI_STATEMENT_COUNT";
567+
stmtCountParam.DbType = DbType.Int16;
568+
stmtCountParam.Value = stmtCount;
569+
cmd.Parameters.Add(stmtCountParam);
570+
571+
DbDataReader reader = cmd.ExecuteReader();
572+
573+
// at least one row in the first result set
574+
Assert.IsTrue(reader.HasRows);
575+
Assert.IsTrue(reader.Read());
576+
577+
for (int i = 1; i < stmtCount; i++)
578+
{
579+
Assert.IsTrue(reader.NextResult());
580+
581+
// at least one row in subsequent result sets
582+
Assert.IsTrue(reader.HasRows);
583+
Assert.IsTrue(reader.Read());
584+
}
585+
Assert.IsFalse(reader.NextResult());
586+
reader.Close();
587+
}
588+
589+
conn.Close();
590+
}
591+
}
534592
}
535593
}
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
using System;
2+
using NUnit.Framework;
3+
using Moq;
4+
using Snowflake.Data.Configuration;
5+
using Snowflake.Data.Core.Tools;
6+
7+
namespace Snowflake.Data.Tests.UnitTests.Configuration
8+
{
9+
[TestFixture]
10+
public class ClientFeatureFlagsTest
11+
{
12+
[Test]
13+
[TestCase(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName, "true", true)]
14+
[TestCase(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName, "TRUE", true)]
15+
[TestCase(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName, "false", false)]
16+
[TestCase(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName, "", false)]
17+
[TestCase(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName, null, false)]
18+
[TestCase(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName, "not a bool value", false)]
19+
[TestCase("OTHER_VARIABLE_NAME", "true", false)]
20+
public void TestEnabledExperimentalAuthentication(string variableName, string variableValue, bool expectedValue) {
21+
// arrange
22+
var environmentOperations = new Mock<EnvironmentOperations>();
23+
environmentOperations
24+
.Setup(e => e.GetEnvironmentVariable(variableName))
25+
.Returns(variableValue);
26+
27+
// act
28+
var clientFeatures = new ClientFeatureFlags(environmentOperations.Object);
29+
30+
// assert
31+
Assert.AreEqual(expectedValue, clientFeatures.IsEnabledExperimentalAuthentication);
32+
}
33+
34+
[Test]
35+
public void TestDisabledExperimentalAuthenticationWhenCouldNotReadEnvVariable()
36+
{
37+
// arrange
38+
var environmentOperations = new Mock<EnvironmentOperations>();
39+
environmentOperations
40+
.Setup(e => e.GetEnvironmentVariable(ClientFeatureFlags.EnabledExperimentalAuthenticationVariableName))
41+
.Throws(() => new Exception("Could not read environmental variable"));
42+
43+
// act
44+
var clientFeatures = new ClientFeatureFlags(environmentOperations.Object);
45+
46+
// assert
47+
Assert.IsFalse(clientFeatures.IsEnabledExperimentalAuthentication);
48+
}
49+
}
50+
}

Snowflake.Data/Client/SnowflakeDbCommand.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ public override int ExecuteNonQuery()
166166
long total = 0;
167167
do
168168
{
169-
if (resultSet.HasResultSet()) continue;
169+
if (resultSet.IsDQL()) continue;
170170
int count = resultSet.CalculateUpdateCount();
171171
if (count < 0)
172172
{
@@ -193,7 +193,7 @@ public override async Task<int> ExecuteNonQueryAsync(CancellationToken cancellat
193193
long total = 0;
194194
do
195195
{
196-
if (resultSet.HasResultSet()) continue;
196+
if (resultSet.IsDQL()) continue;
197197
int count = resultSet.CalculateUpdateCount();
198198
if (count < 0)
199199
{

Snowflake.Data/Client/SnowflakeDbDataReader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public override bool HasRows
7575
{
7676
get
7777
{
78-
return resultSet.HasResultSet() && resultSet.HasRows();
78+
return !resultSet.isClosed && resultSet.HasRows();
7979
}
8080
}
8181

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using System;
2+
using Snowflake.Data.Core.Tools;
3+
using Snowflake.Data.Log;
4+
5+
namespace Snowflake.Data.Configuration
6+
{
7+
internal class ClientFeatureFlags
8+
{
9+
private static readonly SFLogger s_logger = SFLoggerFactory.GetLogger<ClientFeatureFlags>();
10+
public bool IsEnabledExperimentalAuthentication { get; set; }
11+
12+
public static readonly ClientFeatureFlags Instance = new ClientFeatureFlags(EnvironmentOperations.Instance);
13+
14+
internal const string EnabledExperimentalAuthenticationVariableName = "SF_ENABLE_EXPERIMENTAL_AUTHENTICATION";
15+
private const bool EnabledExperimentalAuthenticationDefaultValue = false;
16+
17+
internal ClientFeatureFlags(EnvironmentOperations environmentOperations)
18+
{
19+
IsEnabledExperimentalAuthentication = ReadEnabledExperimentalAuthentication(environmentOperations);
20+
}
21+
22+
private bool ReadEnabledExperimentalAuthentication(EnvironmentOperations environmentOperations)
23+
{
24+
try
25+
{
26+
var isEnabledString = environmentOperations.GetEnvironmentVariable(EnabledExperimentalAuthenticationVariableName);
27+
if (string.IsNullOrEmpty(isEnabledString))
28+
{
29+
s_logger.Debug($"Variable '{EnabledExperimentalAuthenticationVariableName}' not set. Using the default value: {EnabledExperimentalAuthenticationDefaultValue}");
30+
return EnabledExperimentalAuthenticationDefaultValue;
31+
}
32+
var isEnabled = bool.Parse(isEnabledString);
33+
s_logger.Debug($"Variable '{EnabledExperimentalAuthenticationVariableName}' was read as: {isEnabled}");
34+
return isEnabled;
35+
}
36+
catch (Exception exception)
37+
{
38+
s_logger.Error($"Could not get or parse '{EnabledExperimentalAuthenticationVariableName}' variable. Used the default value: {EnabledExperimentalAuthenticationDefaultValue}.", exception);
39+
return EnabledExperimentalAuthenticationDefaultValue;
40+
}
41+
}
42+
}
43+
}

Snowflake.Data/Core/ResultSetUtil.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet)
2626
{
2727
updateCount += resultSet.GetInt64(i);
2828
}
29-
29+
resultSet.Rewind();
3030
break;
3131
case SFStatementType.COPY:
3232
var index = resultSet.sfResultSetMetaData.GetColumnIndexByName("rows_loaded");
@@ -49,6 +49,7 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet)
4949
}
5050
break;
5151
case SFStatementType.SELECT:
52+
// DbDataReader.RecordsAffected returns -1 for SELECT statement
5253
updateCount = -1;
5354
break;
5455
default:
@@ -62,7 +63,7 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet)
6263
return (int)updateCount;
6364
}
6465

65-
internal static bool HasResultSet(this SFBaseResultSet resultSet)
66+
internal static bool IsDQL(this SFBaseResultSet resultSet)
6667
{
6768
if (resultSet.isClosed) return false;
6869

0 commit comments

Comments
 (0)