Skip to content

Commit f3871ab

Browse files
SNOW-2105991: Use pre-computed row counts more aggressively (#3358)
Co-authored-by: Mahesh Vashishtha <mahesh.vashishtha@snowflake.com>
1 parent 44a69d4 commit f3871ab

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+323
-351
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#### Improvements
88

99
- Added support for reading XML files with namespaces using `rowTag` and `stripNamespaces` options.
10+
- Added a new argument to `DataFrame.describe` called `strings_include_math_stats` that triggers `stddev` and `mean` to be calculated for String columns.
1011

1112
### Snowpark Local Testing Updates
1213

@@ -20,7 +21,7 @@
2021

2122
- Set the default value of the `index` parameter to `False` for `DataFrame.to_view`, `Series.to_view`, `DataFrame.to_dynamic_table`, and `Series.to_dynamic_table`.
2223
- Added `iceberg_version` option to table creation functions.
23-
- Added a new argument to `Dataframe.describe` called `strings_include_math_stats` that triggers `stddev` and `mean` to be calculated for String columns.
24+
- Reduced query count for many operations, including `insert`, `repr`, and `groupby`, that previously issued a query to retrieve the input data's size.
2425

2526
## 1.32.0 (2025-05-15)
2627

src/snowflake/snowpark/modin/plugin/_internal/frame.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -738,10 +738,7 @@ def num_rows(self) -> int:
738738
Returns:
739739
Number of rows in this frame.
740740
"""
741-
num_rows = count_rows(self.ordered_dataframe)
742-
self.ordered_dataframe.row_count = num_rows
743-
self.ordered_dataframe.row_count_upper_bound = num_rows
744-
return num_rows
741+
return count_rows(self.ordered_dataframe)
745742

746743
def has_unique_index(self, axis: Optional[int] = 0) -> bool:
747744
"""

src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -680,6 +680,7 @@ def select(
680680
row_count_snowflake_quoted_identifier=self.row_count_snowflake_quoted_identifier,
681681
)
682682

683+
new_df.row_count = self.row_count
683684
# Update the row count upper bound
684685
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
685686
self, DataFrameOperation.SELECT, args={}
@@ -746,6 +747,8 @@ def union_all(self, other: "OrderedDataFrame") -> "OrderedDataFrame":
746747
DataFrameReference(snowpark_dataframe, result_column_quoted_identifiers),
747748
projected_column_snowflake_quoted_identifiers=result_column_quoted_identifiers,
748749
)
750+
if self.row_count is not None and other.row_count is not None:
751+
new_df.row_count = self.row_count + other.row_count
749752
# Update the row count upper bound
750753
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
751754
self, DataFrameOperation.UNION_ALL, args={"other": other}
@@ -849,6 +852,7 @@ def sort(
849852
# No need to reset row count, since sorting should not add/drop rows.
850853
row_count_snowflake_quoted_identifier=self.row_count_snowflake_quoted_identifier,
851854
)
855+
new_df.row_count = self.row_count
852856
# Update the row count upper bound
853857
new_df.row_count_upper_bound = RowCountEstimator.upper_bound(
854858
self, DataFrameOperation.SORT, args={}

src/snowflake/snowpark/modin/plugin/_internal/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1889,9 +1889,14 @@ def count_rows(df: OrderedDataFrame) -> int:
18891889
"""
18901890
Returns the number of rows of a Snowpark DataFrame.
18911891
"""
1892+
if df.row_count is not None:
1893+
return df.row_count
18921894
df = df.ensure_row_count_column()
18931895
rowset = df.select(df.row_count_snowflake_quoted_identifier).limit(1).collect()
1894-
return 0 if len(rowset) == 0 else rowset[0][0]
1896+
row_count = 0 if len(rowset) == 0 else rowset[0][0]
1897+
df.row_count = row_count
1898+
df.row_count_upper_bound = row_count
1899+
return row_count
18951900

18961901

18971902
def append_columns(

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 40 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -13216,62 +13216,68 @@ def build_repr_df(
1321613216
# 2. retrieve all columns
1321713217
# 3. filter on rows with recursive count
1321813218

13219-
# Previously, 2 queries were issued, and a first version replaced them with a single query and a join
13220-
# the solution here uses a window function. This may lead to perf regressions, track these here SNOW-984177.
13221-
# Ensure that our reference to self._modin_frame is updated with cached row count and position.
13222-
self._modin_frame = (
13223-
self._modin_frame.ensure_row_position_column().ensure_row_count_column()
13224-
)
13225-
row_count_pandas_label = (
13226-
ROW_COUNT_COLUMN_LABEL
13227-
if len(self._modin_frame.data_column_pandas_index_names) == 1
13228-
else (ROW_COUNT_COLUMN_LABEL,)
13229-
* len(self._modin_frame.data_column_pandas_index_names)
13230-
)
13231-
frame_with_row_count_and_position = InternalFrame.create(
13232-
ordered_dataframe=self._modin_frame.ordered_dataframe,
13233-
data_column_pandas_labels=self._modin_frame.data_column_pandas_labels
13234-
+ [row_count_pandas_label],
13235-
data_column_snowflake_quoted_identifiers=self._modin_frame.data_column_snowflake_quoted_identifiers
13236-
+ [self._modin_frame.row_count_snowflake_quoted_identifier],
13237-
data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names,
13238-
index_column_pandas_labels=self._modin_frame.index_column_pandas_labels,
13239-
index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers,
13240-
data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types
13241-
+ [None],
13242-
index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types,
13243-
)
13219+
frame = self._modin_frame.ensure_row_position_column()
13220+
use_cached_row_count = frame.ordered_dataframe.row_count is not None
1324413221

13245-
row_count_identifier = (
13246-
frame_with_row_count_and_position.row_count_snowflake_quoted_identifier
13247-
)
13222+
# If the row count is already cached, there's no need to include it in the query.
13223+
if use_cached_row_count:
13224+
row_count_expr = pandas_lit(frame.ordered_dataframe.row_count)
13225+
else:
13226+
# Previously, 2 queries were issued, and a first version replaced them with a single query and a join;
13227+
# the solution here uses a window function. This may lead to perf regressions; track them in SNOW-984177.
13228+
# Ensure that the local `frame` carries the row count column (the row position column was ensured above).
13229+
frame = frame.ensure_row_count_column()
13230+
row_count_pandas_label = (
13231+
ROW_COUNT_COLUMN_LABEL
13232+
if len(frame.data_column_pandas_index_names) == 1
13233+
else (ROW_COUNT_COLUMN_LABEL,)
13234+
* len(frame.data_column_pandas_index_names)
13235+
)
13236+
frame = InternalFrame.create(
13237+
ordered_dataframe=frame.ordered_dataframe,
13238+
data_column_pandas_labels=frame.data_column_pandas_labels
13239+
+ [row_count_pandas_label],
13240+
data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers
13241+
+ [frame.row_count_snowflake_quoted_identifier],
13242+
data_column_pandas_index_names=frame.data_column_pandas_index_names,
13243+
index_column_pandas_labels=frame.index_column_pandas_labels,
13244+
index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
13245+
data_column_types=frame.cached_data_column_snowpark_pandas_types
13246+
+ [None],
13247+
index_column_types=frame.cached_index_column_snowpark_pandas_types,
13248+
)
13249+
13250+
row_count_expr = col(frame.row_count_snowflake_quoted_identifier)
1324813251
row_position_snowflake_quoted_identifier = (
13249-
frame_with_row_count_and_position.row_position_snowflake_quoted_identifier
13252+
frame.row_position_snowflake_quoted_identifier
1325013253
)
1325113254

1325213255
# filter frame based on num_rows.
1325313256
# always return all columns as this may also result in a query.
1325413257
# in the future could analyze plan to see whether retrieving column count would trigger a query, if not
1325513258
# simply filter out based on static schema
1325613259
num_rows_for_head_and_tail = num_rows_to_display // 2 + 1
13257-
new_frame = frame_with_row_count_and_position.filter(
13260+
new_frame = frame.filter(
1325813261
(
1325913262
col(row_position_snowflake_quoted_identifier)
1326013263
<= num_rows_for_head_and_tail
1326113264
)
1326213265
| (
1326313266
col(row_position_snowflake_quoted_identifier)
13264-
>= col(row_count_identifier) - num_rows_for_head_and_tail
13267+
>= row_count_expr - num_rows_for_head_and_tail
1326513268
)
1326613269
)
1326713270

1326813271
# retrieve frame as pandas object
1326913272
new_qc = SnowflakeQueryCompiler(new_frame)
1327013273
pandas_frame = new_qc.to_pandas()
1327113274

13272-
# remove last column after first retrieving row count
13273-
row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
13274-
pandas_frame = pandas_frame.iloc[:, :-1]
13275+
if use_cached_row_count:
13276+
row_count = frame.ordered_dataframe.row_count
13277+
else:
13278+
# remove last column after first retrieving row count
13279+
row_count = 0 if len(pandas_frame) == 0 else pandas_frame.iat[0, -1]
13280+
pandas_frame = pandas_frame.iloc[:, :-1]
1327513281
col_count = len(pandas_frame.columns)
1327613282

1327713283
return row_count, col_count, pandas_frame

tests/integ/modin/crosstab/test_crosstab.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def eval_func(args_list):
223223
def test_basic_crosstab_with_df_and_series_objs_pandas_errors_columns(
224224
self, dropna, a, b, c
225225
):
226-
query_count = 4
226+
query_count = 2
227227
join_count = 1 if dropna else 2
228228
a = native_pd.Series(
229229
a,
@@ -269,7 +269,7 @@ def eval_func(args_list):
269269
def test_basic_crosstab_with_df_and_series_objs_pandas_errors_index(
270270
self, dropna, a, b, c
271271
):
272-
query_count = 6
272+
query_count = 4
273273
join_count = 5 if dropna else 11
274274
a = native_pd.Series(
275275
a,
@@ -556,7 +556,7 @@ def test_values(self, dropna, aggfunc, basic_crosstab_dfs):
556556

557557
@pytest.mark.parametrize("aggfunc", AGGFUNCS_THAT_CANNOT_PRODUCE_NAN)
558558
def test_values_series_like(self, dropna, aggfunc, basic_crosstab_dfs):
559-
query_count = 5
559+
query_count = 3
560560
join_count = 2 if dropna else 3
561561
native_df, snow_df = basic_crosstab_dfs
562562

@@ -646,7 +646,7 @@ def test_values_unsupported_aggfunc(basic_crosstab_dfs):
646646
)
647647

648648

649-
@sql_count_checker(query_count=4)
649+
@sql_count_checker(query_count=2)
650650
def test_values_series_like_unsupported_aggfunc(basic_crosstab_dfs):
651651
# The query count above comes from building the DataFrame
652652
# that we pass in to pivot table.

tests/integ/modin/frame/test_empty.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
({"A": [np.nan]}, "np nan column"),
2828
],
2929
)
30-
@sql_count_checker(query_count=1)
30+
@sql_count_checker(query_count=0)
3131
def test_dataframe_empty_param(dataframe_input, test_case_name):
3232
eval_snowpark_pandas_result(
3333
pd.DataFrame(dataframe_input),

tests/integ/modin/frame/test_from_dict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def test_from_dict_orient_tight():
7272
)
7373

7474

75-
@sql_count_checker(query_count=7)
75+
@sql_count_checker(query_count=5)
7676
def test_from_dict_series_values():
7777
# TODO(SNOW-1857349): Provide a lazy implementation for this case
7878
data = {i: pd.Series(range(1)) for i in range(2)}

tests/integ/modin/frame/test_getitem.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,8 @@ def get_helper(df):
8484
else:
8585
return df[key]
8686

87-
# 5 extra queries for iter
88-
with SqlCounter(query_count=6 if isinstance(key, native_pd.Index) else 1):
87+
# 4 extra queries for iter
88+
with SqlCounter(query_count=5 if isinstance(key, native_pd.Index) else 1):
8989
eval_snowpark_pandas_result(
9090
default_index_snowpark_pandas_df,
9191
default_index_native_df,
@@ -119,8 +119,8 @@ def get_helper(df):
119119
native_df = native_pd.DataFrame(data)
120120
snowpark_df = pd.DataFrame(native_df)
121121

122-
# 5 extra queries for iter
123-
with SqlCounter(query_count=6 if isinstance(key, native_pd.Index) else 1):
122+
# 4 extra queries for iter
123+
with SqlCounter(query_count=5 if isinstance(key, native_pd.Index) else 1):
124124
eval_snowpark_pandas_result(
125125
snowpark_df,
126126
native_df,

tests/integ/modin/frame/test_iloc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -425,7 +425,7 @@ def test_df_iloc_get_empty_key(
425425
)
426426

427427

428-
@sql_count_checker(query_count=2)
428+
@sql_count_checker(query_count=1)
429429
def test_df_iloc_get_empty(empty_snowpark_pandas_df):
430430
_ = empty_snowpark_pandas_df.iloc[0]
431431

@@ -1811,8 +1811,8 @@ def test_df_iloc_set_with_row_key_list(
18111811
else:
18121812
snow_row_pos = row_pos
18131813

1814-
# 2 extra queries for iter
1815-
expected_query_count = 3 if isinstance(snow_row_pos, pd.Index) else 1
1814+
# 1 extra query for iter
1815+
expected_query_count = 2 if isinstance(snow_row_pos, pd.Index) else 1
18161816
expected_join_count = 2 if isinstance(item_values, int) else 3
18171817

18181818
with SqlCounter(query_count=expected_query_count, join_count=expected_join_count):

0 commit comments

Comments
 (0)