Skip to content

Commit 055efb2

Browse files
SNOW-2102617 Enable result sets for DML (#1169)
1 parent 866d376 commit 055efb2

File tree

4 files changed

+74
-15
lines changed

4 files changed

+74
-15
lines changed

Snowflake.Data.Tests/IntegrationTests/SFMultiStatementsIT.cs

Lines changed: 68 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ public void TestSelectWithoutBinding()
7979
conn.Close();
8080
}
8181
}
82-
82+
8383
[Test]
8484
public async Task TestSelectAsync()
8585
{
@@ -120,7 +120,7 @@ public async Task TestSelectAsync()
120120
conn.Close();
121121
}
122122
}
123-
123+
124124
[Test]
125125
public void TestSelectWithBinding()
126126
{
@@ -236,17 +236,17 @@ public void TestMixedQueryTypeWithBinding()
236236
DbDataReader reader = cmd.ExecuteReader();
237237

238238
// result of create
239-
Assert.IsFalse(reader.HasRows);
239+
Assert.IsTrue(reader.HasRows);
240240
Assert.AreEqual(0, reader.RecordsAffected);
241241

242242
// result of insert #1
243243
Assert.IsTrue(reader.NextResult());
244-
Assert.IsFalse(reader.HasRows);
244+
Assert.IsTrue(reader.HasRows);
245245
Assert.AreEqual(1, reader.RecordsAffected);
246246

247247
// result of insert #2
248248
Assert.IsTrue(reader.NextResult());
249-
Assert.IsFalse(reader.HasRows);
249+
Assert.IsTrue(reader.HasRows);
250250
Assert.AreEqual(2, reader.RecordsAffected);
251251

252252
// result of select
@@ -266,7 +266,7 @@ public void TestMixedQueryTypeWithBinding()
266266

267267
// result of drop
268268
Assert.IsTrue(reader.NextResult());
269-
Assert.IsFalse(reader.HasRows);
269+
Assert.IsTrue(reader.HasRows);
270270
Assert.AreEqual(0, reader.RecordsAffected);
271271

272272
Assert.IsFalse(reader.NextResult());
@@ -382,7 +382,7 @@ public void TestWithAllQueryTypes()
382382

383383
// result of create
384384
Assert.IsTrue(reader.NextResult());
385-
Assert.IsFalse(reader.HasRows);
385+
Assert.IsTrue(reader.HasRows);
386386
Assert.AreEqual(0, reader.RecordsAffected);
387387

388388
// result of explain
@@ -400,7 +400,7 @@ public void TestWithAllQueryTypes()
400400

401401
// result of insert
402402
Assert.IsTrue(reader.NextResult());
403-
Assert.IsFalse(reader.HasRows);
403+
Assert.IsTrue(reader.HasRows);
404404
Assert.AreEqual(1, reader.RecordsAffected);
405405

406406
// result of describe
@@ -420,7 +420,7 @@ public void TestWithAllQueryTypes()
420420

421421
// result of create
422422
Assert.IsTrue(reader.NextResult());
423-
Assert.IsFalse(reader.HasRows);
423+
Assert.IsTrue(reader.HasRows);
424424
Assert.AreEqual(0, reader.RecordsAffected);
425425

426426
// result of call
@@ -434,7 +434,7 @@ public void TestWithAllQueryTypes()
434434

435435
// result of use
436436
Assert.IsTrue(reader.NextResult());
437-
Assert.IsFalse(reader.HasRows);
437+
Assert.IsTrue(reader.HasRows);
438438
Assert.AreEqual(0, reader.RecordsAffected);
439439

440440
Assert.IsFalse(reader.NextResult());
@@ -531,5 +531,63 @@ public void TestWithMultipleStatementSetting()
531531
conn.Close();
532532
}
533533
}
534+
535+
[Test]
536+
public void TestResultSetReturnedForAllQueryTypes()
537+
{
538+
using (DbConnection conn = new SnowflakeDbConnection())
539+
{
540+
conn.ConnectionString = ConnectionString;
541+
conn.Open();
542+
543+
using (DbCommand cmd = conn.CreateCommand())
544+
{
545+
cmd.CommandText = "set query_tag = (select 'dummy_tag');" +
546+
"alter session set query_tag='dummy_tag';" +
547+
"select 1;" +
548+
$"create or replace temporary table {TableName}(c1 varchar);" +
549+
$"explain using text select * from {TableName};" +
550+
"show parameters;" +
551+
$"insert into {TableName} values ('str1');" +
552+
$"update {TableName} set c1 = 'str2';" +
553+
$"select * from {TableName};" +
554+
$"desc table {TableName};" +
555+
$"copy into @%{TableName} from {TableName};" +
556+
$"list @%{TableName};" +
557+
$"remove @%{TableName};" +
558+
"create or replace temporary procedure P1() returns varchar language javascript as $$ return ''; $$;" +
559+
"call p1();" +
560+
$"use role {testConfig.role}";
561+
562+
var stmtCount = 16;
563+
564+
// Set statement count
565+
var stmtCountParam = cmd.CreateParameter();
566+
stmtCountParam.ParameterName = "MULTI_STATEMENT_COUNT";
567+
stmtCountParam.DbType = DbType.Int16;
568+
stmtCountParam.Value = stmtCount;
569+
cmd.Parameters.Add(stmtCountParam);
570+
571+
DbDataReader reader = cmd.ExecuteReader();
572+
573+
// at least one row in the first result set
574+
Assert.IsTrue(reader.HasRows);
575+
Assert.IsTrue(reader.Read());
576+
577+
for (int i = 1; i < stmtCount; i++)
578+
{
579+
Assert.IsTrue(reader.NextResult());
580+
581+
// at least one row in subsequent result sets
582+
Assert.IsTrue(reader.HasRows);
583+
Assert.IsTrue(reader.Read());
584+
}
585+
Assert.IsFalse(reader.NextResult());
586+
reader.Close();
587+
}
588+
589+
conn.Close();
590+
}
591+
}
534592
}
535593
}

Snowflake.Data/Client/SnowflakeDbCommand.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ public override int ExecuteNonQuery()
166166
long total = 0;
167167
do
168168
{
169-
if (resultSet.HasResultSet()) continue;
169+
if (resultSet.IsDQL()) continue;
170170
int count = resultSet.CalculateUpdateCount();
171171
if (count < 0)
172172
{
@@ -193,7 +193,7 @@ public override async Task<int> ExecuteNonQueryAsync(CancellationToken cancellat
193193
long total = 0;
194194
do
195195
{
196-
if (resultSet.HasResultSet()) continue;
196+
if (resultSet.IsDQL()) continue;
197197
int count = resultSet.CalculateUpdateCount();
198198
if (count < 0)
199199
{

Snowflake.Data/Client/SnowflakeDbDataReader.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public override bool HasRows
7575
{
7676
get
7777
{
78-
return resultSet.HasResultSet() && resultSet.HasRows();
78+
return !resultSet.isClosed && resultSet.HasRows();
7979
}
8080
}
8181

Snowflake.Data/Core/ResultSetUtil.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet)
2626
{
2727
updateCount += resultSet.GetInt64(i);
2828
}
29-
29+
resultSet.Rewind();
3030
break;
3131
case SFStatementType.COPY:
3232
var index = resultSet.sfResultSetMetaData.GetColumnIndexByName("rows_loaded");
@@ -49,6 +49,7 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet)
4949
}
5050
break;
5151
case SFStatementType.SELECT:
52+
// DbDataReader.RecordsAffected returns -1 for SELECT statement
5253
updateCount = -1;
5354
break;
5455
default:
@@ -62,7 +63,7 @@ internal static int CalculateUpdateCount(this SFBaseResultSet resultSet)
6263
return (int)updateCount;
6364
}
6465

65-
internal static bool HasResultSet(this SFBaseResultSet resultSet)
66+
internal static bool IsDQL(this SFBaseResultSet resultSet)
6667
{
6768
if (resultSet.isClosed) return false;
6869

0 commit comments

Comments
 (0)