diff --git a/DESCRIPTION.md b/DESCRIPTION.md index 56b93c9e0..acd254db0 100644 --- a/DESCRIPTION.md +++ b/DESCRIPTION.md @@ -9,6 +9,8 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne # Release Notes - v4.2.0(TBD) - Added support for async I/O. Asynchronous version of connector is available via `snowflake.connector.aio` module. + - Added `SnowflakeCursor.stats` property to expose granular DML statistics (rows inserted, deleted, updated, and duplicates) for operations like CTAS where `rowcount` is insufficient. + - v4.1.1(TBD) - Relaxed pandas dependency requirements for Python below 3.12. - Changed CRL cache cleanup background task to daemon to avoid blocking main thread. diff --git a/src/snowflake/connector/aio/_cursor.py b/src/snowflake/connector/aio/_cursor.py index a0160dc4c..5466d7045 100644 --- a/src/snowflake/connector/aio/_cursor.py +++ b/src/snowflake/connector/aio/_cursor.py @@ -378,6 +378,10 @@ async def _init_result_and_meta(self, data: dict[Any, Any]) -> None: self._rownumber = -1 self._result_state = ResultState.VALID + # Extract stats object if available (for DML operations like CTAS, INSERT, UPDATE, DELETE) + self._stats_data = data.get("stats", None) + logger.debug("Execution DML stats: %s", self.stats) + # don't update the row count when the result is returned from `describe` method if is_dml and "rowset" in data and len(data["rowset"]) > 0: updated_rows = 0 diff --git a/src/snowflake/connector/cursor.py b/src/snowflake/connector/cursor.py index c13ab242c..e408df4a7 100644 --- a/src/snowflake/connector/cursor.py +++ b/src/snowflake/connector/cursor.py @@ -418,6 +418,10 @@ def __init__( self._log_max_query_length = connection.log_max_query_length self._inner_cursor: SnowflakeCursorBase | None = None self._prefetch_hook = None + self._stats_data: dict[str, int] | None = ( + None # Stores stats from response for DML operations + ) + self._rownumber: int | None = None self.reset() @@ -454,6 +458,23 @@ def 
_description_internal(self) -> list[ResultMetadataV2]: def rowcount(self) -> int | None: return self._total_rowcount if self._total_rowcount >= 0 else None + @property + def stats(self) -> QueryResultStats | None: + """Returns detailed rows affected statistics for DML operations. + + Returns a NamedTuple with fields: + - num_rows_inserted: Number of rows inserted + - num_rows_deleted: Number of rows deleted + - num_rows_updated: Number of rows updated + - num_dml_duplicates: Number of duplicates in DML statement + + Every field is None if no DML stats are available - this includes DML operations where no rows were + affected as well as other types of SQL statements (e.g. DDL, DQL). + """ + if self._stats_data is None: + return QueryResultStats(None, None, None, None) + return QueryResultStats.from_dict(self._stats_data) + @property def rownumber(self) -> int | None: return self._rownumber if self._rownumber >= 0 else None @@ -1201,6 +1222,10 @@ def _init_result_and_meta(self, data: dict[Any, Any]) -> None: self._rownumber = -1 self._result_state = ResultState.VALID + # Extract stats object if available (for DML operations like CTAS, INSERT, UPDATE, DELETE) + self._stats_data = data.get("stats", None) + logger.debug("Execution DML stats: %s", self.stats) + # don't update the row count when the result is returned from `describe` method if is_dml and "rowset" in data and len(data["rowset"]) > 0: updated_rows = 0 @@ -2007,3 +2032,26 @@ def __getattr__(name): ) return None raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +class QueryResultStats(NamedTuple): + """ + Statistics for rows affected by a DML operation. + A None value means that particular statistic is unknown - it was not returned by the backend service. 
+ + Originally added to expose DML statistics of CTAS statements - SNOW-295953 + """ + + num_rows_inserted: int | None = None + num_rows_deleted: int | None = None + num_rows_updated: int | None = None + num_dml_duplicates: int | None = None + + @classmethod + def from_dict(cls, stats_dict: dict[str, int]) -> QueryResultStats: + return cls( + num_rows_inserted=stats_dict.get("numRowsInserted", None), + num_rows_deleted=stats_dict.get("numRowsDeleted", None), + num_rows_updated=stats_dict.get("numRowsUpdated", None), + num_dml_duplicates=stats_dict.get("numDmlDuplicates", None), + ) diff --git a/test/integ/aio_it/test_connection_async.py b/test/integ/aio_it/test_connection_async.py index dd5ef70a5..049ade547 100644 --- a/test/integ/aio_it/test_connection_async.py +++ b/test/integ/aio_it/test_connection_async.py @@ -28,6 +28,8 @@ from snowflake.connector.aio._description import CLIENT_NAME from snowflake.connector.compat import IS_WINDOWS from snowflake.connector.connection import DEFAULT_CLIENT_PREFETCH_THREADS +from snowflake.connector.constants import PARAMETER_MULTI_STATEMENT_COUNT +from snowflake.connector.cursor import QueryResultStats from snowflake.connector.errorcode import ( ER_CONNECTION_IS_CLOSED, ER_FAILED_PROCESSING_PYFORMAT, @@ -1838,3 +1840,682 @@ async def test_no_new_warnings_or_errors_on_successful_basic_select(conn_cnx, ca f"Error count increased from {baseline_error_count} to {test_error_count}. 
" f"New errors: {[r.getMessage() for r in caplog.records if r.levelno >= logging.ERROR]}" ) + + +@pytest.mark.skipolddriver +async def test_ctas_stats(conn_cnx): + """Test that cursor.rowcount and cursor.stats work for CTAS operations.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + await cur.execute( + "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + assert ( + cur.rowcount == 1 + ), f"Expected rowcount 1 for CTAS, got {cur.rowcount}" + # stats should contain the details as a NamedTuple + assert cur.stats is not None, "stats should not be None for CTAS" + assert ( + cur.stats.num_rows_inserted == 3 + ), f"Expected num_rows_inserted=3, got {cur.stats.num_rows_inserted}" + assert cur.stats.num_rows_deleted == 0 + assert cur.stats.num_rows_updated == 0 + assert cur.stats.num_dml_duplicates == 0 + + +@pytest.mark.skipolddriver +async def test_create_view_stats(conn_cnx): + """Test that cursor.stats returns None fields for VIEW operations.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + await cur.execute( + "create temp view test_view_stats as select col1 from values (1), (2), (3) as t(col1)" + ) + assert ( + cur.rowcount == 1 + ), f"Expected rowcount 1 for VIEW, got {cur.rowcount}" + # VIEW operations don't return DML stats, all fields should be None + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +@pytest.mark.skipolddriver +async def test_cvas_separate_cursors_stats(conn_cnx): + """Test cursor.stats with CVAS in separate cursor from the one used for CTAS of the table.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + await cur.execute( + "create temp table test_table (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + async with conn.cursor() as cur: + await 
cur.execute( + "create temp view test_view as select col1 from test_table" + ) + assert ( + cur.rowcount == 1 + ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend" + # VIEW operations don't return DML stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +@pytest.mark.skipolddriver +async def test_cvas_one_cursor_stats(conn_cnx): + """Test cursor.stats with CVAS in the same cursor - make sure it's cleaned up after usage.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + await cur.execute( + "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + await cur.execute( + "create temp view test_view as select col1 from test_ctas_stats" + ) + assert ( + cur.rowcount == 1 + ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend" + # VIEW operations don't return DML stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +def _assert_stats(actual_stats, expected_stats): + """Helper function to assert stats values. 
+ + Args: + actual_stats: The actual QueryResultStats from cursor.stats + expected_stats: Expected QueryResultStats to compare against + """ + assert actual_stats is not None, "stats should not be None" + assert isinstance( + expected_stats, QueryResultStats + ), "expected_stats must be a QueryResultStats instance" + + assert ( + actual_stats.num_rows_inserted == expected_stats.num_rows_inserted + ), f"Expected num_rows_inserted={expected_stats.num_rows_inserted}, got {actual_stats.num_rows_inserted}" + + assert ( + actual_stats.num_rows_deleted == expected_stats.num_rows_deleted + ), f"Expected num_rows_deleted={expected_stats.num_rows_deleted}, got {actual_stats.num_rows_deleted}" + + assert ( + actual_stats.num_rows_updated == expected_stats.num_rows_updated + ), f"Expected num_rows_updated={expected_stats.num_rows_updated}, got {actual_stats.num_rows_updated}" + + assert ( + actual_stats.num_dml_duplicates == expected_stats.num_dml_duplicates + ), f"Expected num_dml_duplicates={expected_stats.num_dml_duplicates}, got {actual_stats.num_dml_duplicates}" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "operation,setup_sql,test_sql,expected_stats", + [ + pytest.param( + "insert_simple", + "create temp table test_stats_table (id int, name varchar(50))", + "insert into test_stats_table values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')", + QueryResultStats( + num_rows_inserted=3, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="insert_simple", + ), + pytest.param( + "update", + """ + create temp table test_stats_table (id int, name varchar(50)); + insert into test_stats_table values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + """, + "update test_stats_table set name = 'Updated' where id in (1, 2)", + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=0, + num_rows_updated=2, + num_dml_duplicates=0, + ), + id="update", + ), + pytest.param( + "delete", + """ + create temp table test_stats_table (id int, name 
varchar(50)); + insert into test_stats_table values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + """, + "delete from test_stats_table where id in (1, 3)", + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=2, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="delete", + ), + pytest.param( + "merge_insert", + """ + create temp table test_stats_target (id int, name varchar(50)); + insert into test_stats_target values (1, 'Alice'); + """, + """ + merge into test_stats_target t + using (select * from values (2, 'Bob'), (3, 'Charlie') as v(id, name)) s + on t.id = s.id + when not matched then insert (id, name) values (s.id, s.name) + """, + QueryResultStats( + num_rows_inserted=2, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="merge_insert", + ), + pytest.param( + "merge_update", + """ + create temp table test_stats_target (id int, name varchar(50)); + insert into test_stats_target values (1, 'Alice'), (2, 'Bob'); + """, + """ + merge into test_stats_target t + using (select * from values (1, 'Alice Updated'), (2, 'Bob Updated') as v(id, name)) s + on t.id = s.id + when matched then update set t.name = s.name + """, + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=0, + num_rows_updated=2, + num_dml_duplicates=0, + ), + id="merge_update", + ), + pytest.param( + "merge_insert_update", + """ + create temp table test_stats_target (id int, name varchar(50)); + insert into test_stats_target values (1, 'Alice'); + """, + """ + merge into test_stats_target t + using (select * from values (1, 'Alice Updated'), (2, 'Bob'), (3, 'Charlie') as v(id, name)) s + on t.id = s.id + when matched then update set t.name = s.name + when not matched then insert (id, name) values (s.id, s.name) + """, + QueryResultStats( + num_rows_inserted=2, + num_rows_deleted=0, + num_rows_updated=1, + num_dml_duplicates=0, + ), + id="merge_insert_update", + ), + pytest.param( + "merge_delete", + """ + create temp table test_stats_target (id int, 
name varchar(50)); + insert into test_stats_target values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + """, + """ + merge into test_stats_target t + using (select * from values (1, 'Delete Me'), (2, 'Delete Me Too') as v(id, name)) s + on t.id = s.id + when matched then delete + """, + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=2, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="merge_delete", + ), + pytest.param( + "merge_insert_update_delete", + """ + create temp table test_stats_target (id int, name varchar(50), status varchar(20)); + insert into test_stats_target values (1, 'Alice', 'active'), (2, 'Bob', 'inactive'), (3, 'Charlie', 'active'); + """, + """ + merge into test_stats_target t + using (select * from values + (1, 'Alice Updated', 'active'), + (2, 'Bob', 'inactive'), + (4, 'David', 'active') + as v(id, name, status)) s + on t.id = s.id + when matched and s.status = 'active' then update set t.name = s.name + when matched and s.status = 'inactive' then delete + when not matched then insert (id, name, status) values (s.id, s.name, s.status) + """, + QueryResultStats( + num_rows_inserted=1, + num_rows_deleted=1, + num_rows_updated=1, + num_dml_duplicates=0, + ), + id="merge_insert_update_delete", + ), + ], +) +async def test_dml_stats_operations( + conn_cnx, operation, setup_sql, test_sql, expected_stats +): + """Test cursor.stats for various DML operations.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Setup + for sql in setup_sql.strip().split(";"): + sql = sql.strip() + if sql: + await cur.execute(sql) + + # Execute test operation + await cur.execute(test_sql) + + # Verify stats + _assert_stats(cur.stats, expected_stats) + + +@pytest.mark.skipolddriver +async def test_copy_into_stats_with_stage(conn_cnx, tmp_path): + """Test cursor.stats for COPY INTO operations with actual stage files.""" + import csv + + # Create a CSV file + csv_file = tmp_path / "test_data.csv" + with open(csv_file, "w", 
newline="") as f: + writer = csv.writer(f) + writer.writerow([1, "Alice"]) + writer.writerow([2, "Bob"]) + writer.writerow([3, "Charlie"]) + + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Create table and stage + await cur.execute( + "create temp table test_copy_stats (id int, name varchar(50))" + ) + await cur.execute("create temp stage test_copy_stage") + + # PUT file to stage + await cur.execute(f"put file://{csv_file} @test_copy_stage") + + # COPY INTO from stage + await cur.execute( + """ + copy into test_copy_stats + from @test_copy_stage/test_data.csv.gz + file_format = (type = csv) + """ + ) + + # Verify stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted == 3 + assert cur.stats.num_rows_deleted == 0 + assert cur.stats.num_rows_updated == 0 + assert cur.stats.num_dml_duplicates == 0 + + +@pytest.mark.skipolddriver +async def test_update_with_dml_duplicates(conn_cnx): + """Test cursor.stats for UPDATE operations that generate numDmlDuplicates. + + When a row in the updated table is matched by multiple rows in the FROM clause, + Snowflake reports the extra matches as duplicates in numDmlDuplicates. 
+ """ + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # test_src has 15 rows: five 0's, five 1's, five 2's + await cur.execute("create temp table test_src (c1 int, c2 int)") + await cur.execute( + "insert into test_src values (0, 100), (1, 100), (2, 100), (0, 100), (1, 100), " + "(2, 100), (0, 100), (1, 100), (2, 100), (0, 100), (1, 100), (2, 100), " + "(0, 100), (1, 100), (2, 100)" + ) + + # test_target has 4 rows: two 0's, one 1, one 2 + await cur.execute("create temp table test_target (c int)") + await cur.execute("insert into test_target values (0), (1), (2), (0)") + + # UPDATE with FROM clause: + # - Each of 5 rows with c1=0 matches 2 rows in test_target (duplicate count: 5 × 1 = 5) + # - Each of 5 rows with c1=1 matches 1 row in test_target (duplicate count: 0) + # - Each of 5 rows with c1=2 matches 1 row in test_target (duplicate count: 0) + # Total duplicates: 5 + await cur.execute( + """ + update test_src set c2 = test_target.c + from test_target + where test_src.c1 = test_target.c + """ + ) + + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=0, + num_rows_updated=15, + num_dml_duplicates=5, + ), + ) + + +@pytest.mark.skipolddriver +async def test_multi_table_insert_overwrite_stats(conn_cnx): + """Test cursor.stats for multi-table INSERT OVERWRITE operations.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Source has 3 values: 5, 15, 25 + await cur.execute("create temp table test_src_multi (c1 int)") + await cur.execute("insert into test_src_multi values (5), (15), (25)") + + # Target tables with existing data + await cur.execute("create temp table test_tgt1 (c int)") + await cur.execute("create temp table test_tgt2 (c int)") + await cur.execute("insert into test_tgt1 values (100), (101)") + await cur.execute("insert into test_tgt2 values (200), (201), (202)") + + # INSERT OVERWRITE ALL evaluates ALL matching WHEN clauses per row: + # - c1=5: no WHENs match → else 
clause → 1 insert (5 to tgt2) + # - c1=15: second WHEN matches → 2 inserts (15 to tgt1, 15 to tgt2) + # - c1=25: both WHENs match → 3 inserts (25 to tgt1, then 25 to tgt1 and 25 to tgt2) + # Result: tgt1=[25,15,25], tgt2=[15,25,5] + # Total: 6 inserts, 5 deletes (2+3 existing rows cleared by OVERWRITE) + await cur.execute( + """ + insert overwrite all + when c1 > 20 then + into test_tgt1 values (c1) + when c1 > 10 then + into test_tgt1 values (c1) + into test_tgt2 values (c1) + else + into test_tgt2 values (c1) + select c1 from test_src_multi + """ + ) + + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=6, + num_rows_deleted=5, + num_rows_updated=0, + num_dml_duplicates=0, + ), + ) + + +@pytest.mark.xfail(reason="Multi-statements does not return stats field") +@pytest.mark.skipolddriver +async def test_multi_statement_in_one_execute(conn_cnx): + """Test that stats reflect the last statement when multiple statements are in one execute.""" + async with conn_cnx( + session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0} + ) as conn: + async with conn.cursor() as cur: + # Execute multiple statements separated by semicolons in one execute call + # Stats should reflect ONLY the last statement + await cur.execute( + """ + create temp table test_multiexec (id int, name varchar(50)); + insert into test_multiexec values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + update test_multiexec set name = 'Updated' where id = 1; + delete from test_multiexec where id = 2; + """ + ) + + # Stats reflect only the last statement (DELETE of 1 row) + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=1, + num_rows_updated=0, + num_dml_duplicates=0, + ), + ) + + +@pytest.mark.skipolddriver +async def test_stats_reset_on_select(conn_cnx): + """Test that stats are reset to None when executing SELECT queries.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Setup and insert + await cur.execute("create temp table 
test_stats_reset (id int)") + await cur.execute("insert into test_stats_reset values (1), (2), (3)") + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=3, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + ) + + # Execute a SELECT - stats should have all None values + await cur.execute("select * from test_stats_reset") + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=None, + num_rows_deleted=None, + num_rows_updated=None, + num_dml_duplicates=None, + ), + ) + + +@pytest.mark.skipolddriver +async def test_truncate_stats(conn_cnx): + """Test cursor.stats for TRUNCATE operations.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Setup + await cur.execute("create temp table test_truncate_stats (id int)") + await cur.execute("insert into test_truncate_stats values (1), (2), (3)") + + # Truncate is expected to report the removed rows as num_rows_deleted + await cur.execute("truncate table test_truncate_stats") + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=3, + num_rows_updated=0, + num_dml_duplicates=0, + ), + ) + + +@pytest.mark.skipolddriver +async def test_empty_result_stats(conn_cnx): + """Test cursor.stats for DML operations that affect zero rows.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Setup + await cur.execute( + "create temp table test_empty_stats (id int, name varchar(50))" + ) + await cur.execute("insert into test_empty_stats values (1, 'Alice')") + + # Update with no matching rows + await cur.execute( + "update test_empty_stats set name = 'Updated' where id = 999" + ) + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=None, + num_rows_deleted=None, + num_rows_updated=None, + num_dml_duplicates=None, + ), + ) + + # Delete with no matching rows + await cur.execute("delete from test_empty_stats where id = 999") + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=None, + 
num_rows_deleted=None, + num_rows_updated=None, + num_dml_duplicates=None, + ), + ) + + +@pytest.mark.skipolddriver +@pytest.mark.xfail( + reason="execute_async stats are not returned from monitoring endpoint yet" +) +@pytest.mark.parametrize( + "operation,setup_sql,test_sql,expected_stats", + [ + pytest.param( + "insert_async", + "create temp table test_async_stats (id int, name varchar(50))", + "insert into test_async_stats values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')", + QueryResultStats( + num_rows_inserted=3, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="insert_async", + ), + pytest.param( + "update_async", + """ + create temp table test_async_stats (id int, name varchar(50)); + insert into test_async_stats values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + """, + "update test_async_stats set name = 'Updated' where id in (1, 2)", + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=0, + num_rows_updated=2, + num_dml_duplicates=0, + ), + id="update_async", + ), + pytest.param( + "delete_async", + """ + create temp table test_async_stats (id int, name varchar(50)); + insert into test_async_stats values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + """, + "delete from test_async_stats where id in (1, 3)", + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=2, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="delete_async", + ), + pytest.param( + "ctas_async", + "", + "create temp table test_async_ctas (id int) as select * from values (1), (2), (3), (4), (5) as t(id)", + QueryResultStats( + num_rows_inserted=5, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="ctas_async", + ), + ], +) +async def test_execute_async_stats( + conn_cnx, operation, setup_sql, test_sql, expected_stats +): + """Test cursor.stats for DML operations executed asynchronously.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Setup + if setup_sql: + for sql in 
setup_sql.strip().split(";"): + sql = sql.strip() + if sql: + await cur.execute(sql) + + # Execute async + await cur.execute_async(test_sql) + query_id = cur.sfqid + + # Get results + await cur.get_results_from_sfqid(query_id) + + # Verify stats are available after getting results + _assert_stats(cur.stats, expected_stats) + + +@pytest.mark.skipolddriver +@pytest.mark.xfail( + reason="execute_async stats are not returned from monitoring endpoint yet" +) +async def test_execute_async_stats_multiple_queries(conn_cnx): + """Test cursor.stats with multiple async queries.""" + async with conn_cnx() as conn: + async with conn.cursor() as cur: + # Setup + await cur.execute( + "create temp table test_multi_async (id int, name varchar(50))" + ) + + # Execute first async query + await cur.execute_async( + "insert into test_multi_async values (1, 'Alice'), (2, 'Bob')" + ) + qid1 = cur.sfqid + + # Execute second async query + await cur.execute_async( + "insert into test_multi_async values (3, 'Charlie')" + ) + qid2 = cur.sfqid + + # Get results for first query + await cur.get_results_from_sfqid(qid1) + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=2, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + ) + + # Get results for second query + await cur.get_results_from_sfqid(qid2) + _assert_stats( + cur.stats, + QueryResultStats( + num_rows_inserted=1, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + ) diff --git a/test/integ/test_connection.py b/test/integ/test_connection.py index b5d490d34..7e61272a8 100644 --- a/test/integ/test_connection.py +++ b/test/integ/test_connection.py @@ -24,6 +24,8 @@ DEFAULT_CLIENT_PREFETCH_THREADS, SnowflakeConnection, ) +from snowflake.connector.constants import PARAMETER_MULTI_STATEMENT_COUNT +from snowflake.connector.cursor import QueryResultStats from snowflake.connector.description import CLIENT_NAME from snowflake.connector.errorcode import ( ER_CONNECTION_IS_CLOSED, @@ 
-1067,7 +1069,7 @@ def test_client_fetch_threads_setting(conn_cnx): @pytest.mark.skipolddriver @pytest.mark.parametrize("disable_request_pooling", [True, False]) def test_ocsp_and_rest_pool_isolation(conn_cnx, disable_request_pooling): - """Each connection’s SessionManager is isolated; OCSP picks the right one.""" + """Each connection's SessionManager is isolated; OCSP picks the right one.""" from snowflake.connector.ssl_wrap_socket import get_current_session_manager # @@ -1896,3 +1898,665 @@ def test_snowflake_version(): assert re.match( version_pattern, conn.snowflake_version ), f"snowflake_version should match pattern 'x.y.z', but got '{conn.snowflake_version}'" + + +@pytest.mark.skipolddriver +def test_ctas_stats(conn_cnx): + """Test that cursor.rowcount and cursor.stats work for CTAS operations.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + assert ( + cur.rowcount == 1 + ), f"Expected rowcount 1 for CTAS, got {cur.rowcount}" + # stats should contain the details as a NamedTuple + assert cur.stats is not None, "stats should not be None for CTAS" + assert ( + cur.stats.num_rows_inserted == 3 + ), f"Expected num_rows_inserted=3, got {cur.stats.num_rows_inserted}" + assert cur.stats.num_rows_deleted == 0 + assert cur.stats.num_rows_updated == 0 + assert cur.stats.num_dml_duplicates == 0 + + +@pytest.mark.skipolddriver +def test_create_view_stats(conn_cnx): + """Test that cursor.stats returns None fields for VIEW operations.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp view test_view_stats as select col1 from values (1), (2), (3) as t(col1)" + ) + assert ( + cur.rowcount == 1 + ), f"Expected rowcount 1 for VIEW, got {cur.rowcount}" + # VIEW operations don't return DML stats, all fields should be None + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert 
cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +@pytest.mark.skipolddriver +def test_cvas_separate_cursors_stats(conn_cnx): + """Test cursor.stats with CVAS in separate cursor from the one used for CTAS of the table.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp table test_table (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + with conn.cursor() as cur: + cur.execute("create temp view test_view as select col1 from test_table") + assert ( + cur.rowcount == 1 + ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend" + # VIEW operations don't return DML stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +@pytest.mark.skipolddriver +def test_cvas_one_cursor_stats(conn_cnx): + """Test cursor.stats with CVAS in the same cursor - make sure it's cleaned up after usage.""" + with conn_cnx() as conn: + with conn.cursor() as cur: + cur.execute( + "create temp table test_ctas_stats (col1 int) as select col1 from values (1), (2), (3) as t(col1)" + ) + cur.execute( + "create temp view test_view as select col1 from test_ctas_stats" + ) + assert ( + cur.rowcount == 1 + ), "Due to old behaviour we should keep rowcount equal to 1 - as the number of rows returned by the backend" + # VIEW operations don't return DML stats + assert cur.stats is not None + assert cur.stats.num_rows_inserted is None + assert cur.stats.num_rows_deleted is None + assert cur.stats.num_rows_updated is None + assert cur.stats.num_dml_duplicates is None + + +def _assert_stats(actual_stats, expected_stats): + """Helper function to assert stats values. 
+ + Args: + actual_stats: The actual QueryResultStats from cursor.stats + expected_stats: Expected QueryResultStats to compare against + """ + assert actual_stats is not None, "stats should not be None" + assert isinstance( + expected_stats, QueryResultStats + ), "expected_stats must be a QueryResultStats instance" + + assert ( + actual_stats.num_rows_inserted == expected_stats.num_rows_inserted + ), f"Expected num_rows_inserted={expected_stats.num_rows_inserted}, got {actual_stats.num_rows_inserted}" + + assert ( + actual_stats.num_rows_deleted == expected_stats.num_rows_deleted + ), f"Expected num_rows_deleted={expected_stats.num_rows_deleted}, got {actual_stats.num_rows_deleted}" + + assert ( + actual_stats.num_rows_updated == expected_stats.num_rows_updated + ), f"Expected num_rows_updated={expected_stats.num_rows_updated}, got {actual_stats.num_rows_updated}" + + assert ( + actual_stats.num_dml_duplicates == expected_stats.num_dml_duplicates + ), f"Expected num_dml_duplicates={expected_stats.num_dml_duplicates}, got {actual_stats.num_dml_duplicates}" + + +@pytest.mark.skipolddriver +@pytest.mark.parametrize( + "operation,setup_sql,test_sql,expected_stats", + [ + pytest.param( + "insert_simple", + "create temp table test_stats_table (id int, name varchar(50))", + "insert into test_stats_table values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')", + QueryResultStats( + num_rows_inserted=3, + num_rows_deleted=0, + num_rows_updated=0, + num_dml_duplicates=0, + ), + id="insert_simple", + ), + pytest.param( + "update", + """ + create temp table test_stats_table (id int, name varchar(50)); + insert into test_stats_table values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie'); + """, + "update test_stats_table set name = 'Updated' where id in (1, 2)", + QueryResultStats( + num_rows_inserted=0, + num_rows_deleted=0, + num_rows_updated=2, + num_dml_duplicates=0, + ), + id="update", + ), + pytest.param( + "delete", + """ + create temp table test_stats_table (id int, name 
varchar(50));
+            insert into test_stats_table values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie');
+            """,
+            "delete from test_stats_table where id in (1, 3)",
+            QueryResultStats(
+                num_rows_inserted=0,
+                num_rows_deleted=2,
+                num_rows_updated=0,
+                num_dml_duplicates=0,
+            ),
+            id="delete",
+        ),
+        pytest.param(
+            "merge_insert",
+            """
+            create temp table test_stats_target (id int, name varchar(50));
+            insert into test_stats_target values (1, 'Alice');
+            """,
+            """
+            merge into test_stats_target t
+            using (select * from values (2, 'Bob'), (3, 'Charlie') as v(id, name)) s
+            on t.id = s.id
+            when not matched then insert (id, name) values (s.id, s.name)
+            """,
+            QueryResultStats(
+                num_rows_inserted=2,
+                num_rows_deleted=0,
+                num_rows_updated=0,
+                num_dml_duplicates=0,
+            ),
+            id="merge_insert",
+        ),
+        pytest.param(
+            "merge_update",
+            """
+            create temp table test_stats_target (id int, name varchar(50));
+            insert into test_stats_target values (1, 'Alice'), (2, 'Bob');
+            """,
+            """
+            merge into test_stats_target t
+            using (select * from values (1, 'Alice Updated'), (2, 'Bob Updated') as v(id, name)) s
+            on t.id = s.id
+            when matched then update set t.name = s.name
+            """,
+            QueryResultStats(
+                num_rows_inserted=0,
+                num_rows_deleted=0,
+                num_rows_updated=2,
+                num_dml_duplicates=0,
+            ),
+            id="merge_update",
+        ),
+        pytest.param(
+            "merge_insert_update",
+            """
+            create temp table test_stats_target (id int, name varchar(50));
+            insert into test_stats_target values (1, 'Alice');
+            """,
+            """
+            merge into test_stats_target t
+            using (select * from values (1, 'Alice Updated'), (2, 'Bob'), (3, 'Charlie') as v(id, name)) s
+            on t.id = s.id
+            when matched then update set t.name = s.name
+            when not matched then insert (id, name) values (s.id, s.name)
+            """,
+            QueryResultStats(
+                num_rows_inserted=2,
+                num_rows_deleted=0,
+                num_rows_updated=1,
+                num_dml_duplicates=0,
+            ),
+            id="merge_insert_update",
+        ),
+        pytest.param(
+            "merge_delete",
+            """
+            create temp table test_stats_target (id int, name varchar(50));
+            insert into test_stats_target values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie');
+            """,
+            """
+            merge into test_stats_target t
+            using (select * from values (1, 'Delete Me'), (2, 'Delete Me Too') as v(id, name)) s
+            on t.id = s.id
+            when matched then delete
+            """,
+            QueryResultStats(
+                num_rows_inserted=0,
+                num_rows_deleted=2,
+                num_rows_updated=0,
+                num_dml_duplicates=0,
+            ),
+            id="merge_delete",
+        ),
+        pytest.param(
+            "merge_insert_update_delete",
+            """
+            create temp table test_stats_target (id int, name varchar(50), status varchar(20));
+            insert into test_stats_target values (1, 'Alice', 'active'), (2, 'Bob', 'inactive'), (3, 'Charlie', 'active');
+            """,
+            """
+            merge into test_stats_target t
+            using (select * from values
+                (1, 'Alice Updated', 'active'),
+                (2, 'Bob', 'inactive'),
+                (4, 'David', 'active')
+                as v(id, name, status)) s
+            on t.id = s.id
+            when matched and s.status = 'active' then update set t.name = s.name
+            when matched and s.status = 'inactive' then delete
+            when not matched then insert (id, name, status) values (s.id, s.name, s.status)
+            """,
+            QueryResultStats(
+                num_rows_inserted=1,
+                num_rows_deleted=1,
+                num_rows_updated=1,
+                num_dml_duplicates=0,
+            ),
+            id="merge_insert_update_delete",
+        ),
+    ],
+)
+def test_dml_stats_operations(conn_cnx, operation, setup_sql, test_sql, expected_stats):
+    """Test cursor.stats for various DML operations."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Setup: run each ';'-separated statement on its own, skipping empty fragments
+            for sql in setup_sql.strip().split(";"):
+                sql = sql.strip()
+                if sql:
+                    cur.execute(sql)
+
+            # Execute test operation
+            cur.execute(test_sql)
+
+            # Verify stats
+            _assert_stats(cur.stats, expected_stats)
+
+
+@pytest.mark.skipolddriver
+def test_copy_into_stats_with_stage(conn_cnx, tmp_path):
+    """Test cursor.stats for COPY INTO operations with actual stage files."""
+    import csv
+
+    # Create a CSV file with three data rows
+    csv_file = tmp_path / "test_data.csv"
+    with open(csv_file, "w", newline="") as f:
+        writer = csv.writer(f)
+        writer.writerow([1, "Alice"])
+        writer.writerow([2, "Bob"])
+        writer.writerow([3, "Charlie"])
+
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Create table and stage
+            cur.execute("create temp table test_copy_stats (id int, name varchar(50))")
+            cur.execute("create temp stage test_copy_stage")
+
+            # PUT file to stage (the COPY below references it as test_data.csv.gz)
+            cur.execute(f"put file://{csv_file} @test_copy_stage")
+
+            # COPY INTO from stage
+            cur.execute(
+                """
+                copy into test_copy_stats
+                from @test_copy_stage/test_data.csv.gz
+                file_format = (type = csv)
+                """
+            )
+
+            # Verify stats: the three CSV rows are expected to be counted as inserts
+            assert cur.stats is not None
+            assert cur.stats.num_rows_inserted == 3
+            assert cur.stats.num_rows_deleted == 0
+            assert cur.stats.num_rows_updated == 0
+            assert cur.stats.num_dml_duplicates == 0
+
+
+@pytest.mark.skipolddriver
+def test_update_with_dml_duplicates(conn_cnx):
+    """Test cursor.stats for UPDATE operations that generate numDmlDuplicates.
+
+    When a row in the updated table is matched by multiple rows in the FROM clause,
+    Snowflake reports the extra matches as duplicates in numDmlDuplicates.
+    """
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # test_src has 15 rows: five 0's, five 1's, five 2's
+            cur.execute("create temp table test_src (c1 int, c2 int)")
+            cur.execute(
+                "insert into test_src values (0, 100), (1, 100), (2, 100), (0, 100), (1, 100), "
+                "(2, 100), (0, 100), (1, 100), (2, 100), (0, 100), (1, 100), (2, 100), "
+                "(0, 100), (1, 100), (2, 100)"
+            )
+
+            # test_target has 4 rows: two 0's, one 1, one 2
+            cur.execute("create temp table test_target (c int)")
+            cur.execute("insert into test_target values (0), (1), (2), (0)")
+
+            # UPDATE with FROM clause:
+            # - Each of 5 rows with c1=0 matches 2 rows in test_target (duplicate count: 5 × 1 = 5)
+            # - Each of 5 rows with c1=1 matches 1 row in test_target (duplicate count: 0)
+            # - Each of 5 rows with c1=2 matches 1 row in test_target (duplicate count: 0)
+            # Total duplicates: 5
+            cur.execute(
+                """
+                update test_src set c2 = test_target.c
+                from test_target
+                where test_src.c1 = test_target.c
+                """
+            )
+
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=0,
+                    num_rows_deleted=0,
+                    num_rows_updated=15,
+                    num_dml_duplicates=5,
+                ),
+            )
+
+
+@pytest.mark.skipolddriver
+def test_multi_table_insert_overwrite_stats(conn_cnx):
+    """Test cursor.stats for multi-table INSERT OVERWRITE operations."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Source has 3 values: 5, 15, 25
+            cur.execute("create temp table test_src_multi (c1 int)")
+            cur.execute("insert into test_src_multi values (5), (15), (25)")
+
+            # Target tables with existing data
+            cur.execute("create temp table test_tgt1 (c int)")
+            cur.execute("create temp table test_tgt2 (c int)")
+            cur.execute("insert into test_tgt1 values (100), (101)")
+            cur.execute("insert into test_tgt2 values (200), (201), (202)")
+
+            # INSERT OVERWRITE ALL evaluates ALL matching WHEN clauses per row:
+            # - c1=5: no WHENs match → else clause → 1 insert (5 to tgt2)
+            # - c1=15: second WHEN matches → 2 inserts (15 to tgt1, 15 to tgt2)
+            # - c1=25: both WHENs match → 3 inserts (25 to tgt1, then 25 to tgt1 and 25 to tgt2)
+            # Result: tgt1=[25,15,25], tgt2=[15,25,5]
+            # Total: 6 inserts, 5 deletes (2+3 existing rows cleared by OVERWRITE)
+            cur.execute(
+                """
+                insert overwrite all
+                    when c1 > 20 then
+                        into test_tgt1 values (c1)
+                    when c1 > 10 then
+                        into test_tgt1 values (c1)
+                        into test_tgt2 values (c1)
+                    else
+                        into test_tgt2 values (c1)
+                select c1 from test_src_multi
+                """
+            )
+
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=6,
+                    num_rows_deleted=5,
+                    num_rows_updated=0,
+                    num_dml_duplicates=0,
+                ),
+            )
+
+
+@pytest.mark.xfail(reason="Multi-statements does not return stats field")
+@pytest.mark.skipolddriver
+def test_multi_statement_in_one_execute(conn_cnx):
+    """Test that stats reflect the last statement when multiple statements are in one execute."""
+    with conn_cnx(session_parameters={PARAMETER_MULTI_STATEMENT_COUNT: 0}) as conn:
+        with conn.cursor() as cur:
+            # Execute multiple statements separated by semicolons in one execute call
+            # The stats should reflect ONLY the last statement
+            cur.execute(
+                """
+                create temp table test_multiexec (id int, name varchar(50));
+                insert into test_multiexec values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie');
+                update test_multiexec set name = 'Updated' where id = 1;
+                delete from test_multiexec where id = 2;
+                """
+            )
+
+            # Stats reflect only the last statement (DELETE of 1 row)
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=0,
+                    num_rows_deleted=1,
+                    num_rows_updated=0,
+                    num_dml_duplicates=0,
+                ),
+            )
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.xfail(
+    reason="execute_async stats are not returned from monitoring endpoint yet"
+)
+@pytest.mark.parametrize(
+    "operation,setup_sql,test_sql,expected_stats",
+    [
+        pytest.param(
+            "insert_async",
+            "create temp table test_async_stats (id int, name varchar(50))",
+            "insert into test_async_stats values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie')",
+            QueryResultStats(
+                num_rows_inserted=3,
+                num_rows_deleted=0,
+                num_rows_updated=0,
+                num_dml_duplicates=0,
+            ),
+            id="insert_async",
+        ),
+        pytest.param(
+            "update_async",
+            """
+            create temp table test_async_stats (id int, name varchar(50));
+            insert into test_async_stats values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie');
+            """,
+            "update test_async_stats set name = 'Updated' where id in (1, 2)",
+            QueryResultStats(
+                num_rows_inserted=0,
+                num_rows_deleted=0,
+                num_rows_updated=2,
+                num_dml_duplicates=0,
+            ),
+            id="update_async",
+        ),
+        pytest.param(
+            "delete_async",
+            """
+            create temp table test_async_stats (id int, name varchar(50));
+            insert into test_async_stats values (1, 'Alice'), (2, 'Bob'), (3, 'Charlie');
+            """,
+            "delete from test_async_stats where id in (1, 3)",
+            QueryResultStats(
+                num_rows_inserted=0,
+                num_rows_deleted=2,
+                num_rows_updated=0,
+                num_dml_duplicates=0,
+            ),
+            id="delete_async",
+        ),
+        pytest.param(
+            "ctas_async",
+            "",
+            "create temp table test_async_ctas (id int) as select * from values (1), (2), (3), (4), (5) as t(id)",
+            QueryResultStats(
+                num_rows_inserted=5,
+                num_rows_deleted=0,
+                num_rows_updated=0,
+                num_dml_duplicates=0,
+            ),
+            id="ctas_async",
+        ),
+    ],
+)
+def test_execute_async_stats(conn_cnx, operation, setup_sql, test_sql, expected_stats):
+    """Test cursor.stats for DML operations executed asynchronously."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Setup (may be empty): run each ';'-separated statement on its own
+            if setup_sql:
+                for sql in setup_sql.strip().split(";"):
+                    sql = sql.strip()
+                    if sql:
+                        cur.execute(sql)
+
+            # Execute async and remember the query id of the submitted statement
+            cur.execute_async(test_sql)
+            query_id = cur.sfqid
+
+            # Get results
+            cur.get_results_from_sfqid(query_id)
+
+            # Verify stats are available after getting results
+            _assert_stats(cur.stats, expected_stats)
+
+
+@pytest.mark.skipolddriver
+@pytest.mark.xfail(
+    reason="execute_async stats are not returned from monitoring endpoint yet"
+)
+def test_execute_async_stats_multiple_queries(conn_cnx):
+    """Test cursor.stats with multiple async queries."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Setup
+            cur.execute("create temp table test_multi_async (id int, name varchar(50))")
+
+            # Execute first async query
+            cur.execute_async(
+                "insert into test_multi_async values (1, 'Alice'), (2, 'Bob')"
+            )
+            qid1 = cur.sfqid
+
+            # Execute second async query
+            cur.execute_async("insert into test_multi_async values (3, 'Charlie')")
+            qid2 = cur.sfqid
+
+            # Get results for first query
+            cur.get_results_from_sfqid(qid1)
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=2,
+                    num_rows_deleted=0,
+                    num_rows_updated=0,
+                    num_dml_duplicates=0,
+                ),
+            )
+
+            # Get results for second query
+            cur.get_results_from_sfqid(qid2)
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=1,
+                    num_rows_deleted=0,
+                    num_rows_updated=0,
+                    num_dml_duplicates=0,
+                ),
+            )
+
+
+@pytest.mark.skipolddriver
+def test_stats_reset_on_select(conn_cnx):
+    """Test that every stats field is reset to None when executing SELECT queries."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Setup and insert
+            cur.execute("create temp table test_stats_reset (id int)")
+            cur.execute("insert into test_stats_reset values (1), (2), (3)")
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=3,
+                    num_rows_deleted=0,
+                    num_rows_updated=0,
+                    num_dml_duplicates=0,
+                ),
+            )
+
+            # Execute a SELECT - stats should have all None values
+            cur.execute("select * from test_stats_reset")
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=None,
+                    num_rows_deleted=None,
+                    num_rows_updated=None,
+                    num_dml_duplicates=None,
+                ),
+            )
+
+
+@pytest.mark.skipolddriver
+def test_truncate_stats(conn_cnx):
+    """Test cursor.stats for TRUNCATE operations."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Setup
+            cur.execute("create temp table test_truncate_stats (id int)")
+            cur.execute("insert into test_truncate_stats values (1), (2), (3)")
+
+            # Truncate is expected to report the rows it removes as deleted
+            cur.execute("truncate table test_truncate_stats")
+            # All 3 existing rows should be counted in num_rows_deleted
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=0,
+                    num_rows_deleted=3,
+                    num_rows_updated=0,
+                    num_dml_duplicates=0,
+                ),
+            )
+
+
+@pytest.mark.skipolddriver
+def test_empty_result_stats(conn_cnx):
+    """Test cursor.stats for DML operations that affect zero rows (all fields expected None)."""
+    with conn_cnx() as conn:
+        with conn.cursor() as cur:
+            # Setup
+            cur.execute("create temp table test_empty_stats (id int, name varchar(50))")
+            cur.execute("insert into test_empty_stats values (1, 'Alice')")
+
+            # Update with no matching rows
+            cur.execute("update test_empty_stats set name = 'Updated' where id = 999")
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=None,
+                    num_rows_deleted=None,
+                    num_rows_updated=None,
+                    num_dml_duplicates=None,
+                ),
+            )
+
+            # Delete with no matching rows
+            cur.execute("delete from test_empty_stats where id = 999")
+            _assert_stats(
+                cur.stats,
+                QueryResultStats(
+                    num_rows_inserted=None,
+                    num_rows_deleted=None,
+                    num_rows_updated=None,
+                    num_dml_duplicates=None,
+                ),
+            )