Skip to content

Commit 3337f40

Browse files
committed
materialize row count for simple projection repr
1 parent df51d38 commit 3337f40

File tree

2 files changed

+62
-35
lines changed

2 files changed

+62
-35
lines changed

src/snowflake/snowpark/modin/plugin/_internal/ordered_dataframe.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,17 @@ def ensure_row_count_column(self) -> "OrderedDataFrame":
441441
)
442442
return ordered_dataframe
443443

444+
def materialize_row_count(self) -> int:
445+
"""
446+
Perform a query to retrieve the row count of this OrderedDataFrame.
447+
448+
Use this function in place of ensure_row_count_column() in scenarios where the extra
449+
query is acceptable, and the embedded `COUNT(*) OVER()` window operation would be too expensive.
450+
Performing a naked `COUNT(*)` and avoiding a potential window function or cross join is
451+
more performance in these scenarios.
452+
"""
453+
return self._dataframe_ref.snowpark_dataframe.count()
454+
444455
def generate_snowflake_quoted_identifiers(
445456
self,
446457
*,

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 51 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -13058,67 +13058,83 @@ def build_repr_df(
1305813058
`row_count` holds the number of rows the DataFrame has, `col_count` the number of columns the DataFrame has, and
1305913059
the pandas dataset with `num_rows` or fewer rows and `num_cols` or fewer columns.
1306013060
"""
13061-
# In order to issue less queries, use following trick:
13062-
# 1. add the row count column holding COUNT(*) OVER () over the snowpark dataframe
13063-
# 2. retrieve all columns
13064-
# 3. filter on rows with recursive count
13061+
# build_repr_df needs to know the row count of the underlying data, as the displayed representation will
13062+
# include the last few rows of the frame.
13063+
#
13064+
# To maximize performance, we use two distinct code paths.
13065+
# If the underlying OrderedDataFrame is a simple projection of a table:
13066+
# 1. Perform a query to retrieve the row count. This query will be cheap because the SQL engine can
13067+
# retrieve the value from table metadata.
13068+
# 2. Directly embed the row count into the filter query as a literal.
13069+
# If the underlying data is NOT a simple projection, we opt to perform fewer queries:
13070+
# 1. add the row count column holding COUNT(*) OVER () over the snowpark dataframe
13071+
# 2. retrieve all columns
13072+
# 3. filter on rows with recursive count
1306513073

1306613074
# Previously, 2 queries were issued, and a first version replaced them with a single query and a join
1306713075
# the solution here uses a window function. This may lead to perf regressions, track these here SNOW-984177.
1306813076
# Ensure that our reference to self._modin_frame is updated with cached row count and position.
13069-
self._modin_frame = (
13070-
self._modin_frame.ensure_row_position_column().ensure_row_count_column()
13071-
)
13072-
row_count_pandas_label = (
13073-
ROW_COUNT_COLUMN_LABEL
13074-
if len(self._modin_frame.data_column_pandas_index_names) == 1
13075-
else (ROW_COUNT_COLUMN_LABEL,)
13076-
* len(self._modin_frame.data_column_pandas_index_names)
13077-
)
13078-
frame_with_row_count_and_position = InternalFrame.create(
13079-
ordered_dataframe=self._modin_frame.ordered_dataframe,
13080-
data_column_pandas_labels=self._modin_frame.data_column_pandas_labels
13081-
+ [row_count_pandas_label],
13082-
data_column_snowflake_quoted_identifiers=self._modin_frame.data_column_snowflake_quoted_identifiers
13083-
+ [self._modin_frame.row_count_snowflake_quoted_identifier],
13084-
data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names,
13085-
index_column_pandas_labels=self._modin_frame.index_column_pandas_labels,
13086-
index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers,
13087-
data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types
13088-
+ [None],
13089-
index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types,
13090-
)
13077+
row_count_value = None
13078+
if self._modin_frame.ordered_dataframe.is_projection_of_table():
13079+
frame = self._modin_frame.ensure_row_position_column()
13080+
row_count_value = frame.ordered_dataframe.materialize_row_count()
13081+
row_count_expr = pandas_lit(row_count_value)
13082+
else:
13083+
frame = (
13084+
self._modin_frame.ensure_row_position_column().ensure_row_count_column()
13085+
)
13086+
row_count_pandas_label = (
13087+
ROW_COUNT_COLUMN_LABEL
13088+
if len(frame.data_column_pandas_index_names) == 1
13089+
else (ROW_COUNT_COLUMN_LABEL,)
13090+
* len(frame.data_column_pandas_index_names)
13091+
)
13092+
frame = InternalFrame.create(
13093+
ordered_dataframe=frame.ordered_dataframe,
13094+
data_column_pandas_labels=frame.data_column_pandas_labels
13095+
+ [row_count_pandas_label],
13096+
data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers
13097+
+ [frame.row_count_snowflake_quoted_identifier],
13098+
data_column_pandas_index_names=frame.data_column_pandas_index_names,
13099+
index_column_pandas_labels=frame.index_column_pandas_labels,
13100+
index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
13101+
data_column_types=frame.cached_data_column_snowpark_pandas_types
13102+
+ [None],
13103+
index_column_types=frame.cached_index_column_snowpark_pandas_types,
13104+
)
1309113105

13092-
row_count_identifier = (
13093-
frame_with_row_count_and_position.row_count_snowflake_quoted_identifier
13094-
)
13106+
row_count_expr = col(frame.row_count_snowflake_quoted_identifier)
1309513107
row_position_snowflake_quoted_identifier = (
13096-
frame_with_row_count_and_position.row_position_snowflake_quoted_identifier
13108+
frame.row_position_snowflake_quoted_identifier
1309713109
)
1309813110

1309913111
# filter frame based on num_rows.
1310013112
# always return all columns as this may also result in a query.
1310113113
# in the future could analyze plan to see whether retrieving column count would trigger a query, if not
1310213114
# simply filter out based on static schema
1310313115
num_rows_for_head_and_tail = num_rows_to_display // 2 + 1
13104-
new_frame = frame_with_row_count_and_position.filter(
13116+
new_frame = frame.filter(
1310513117
(
1310613118
col(row_position_snowflake_quoted_identifier)
1310713119
<= num_rows_for_head_and_tail
1310813120
)
1310913121
| (
1311013122
col(row_position_snowflake_quoted_identifier)
13111-
>= col(row_count_identifier) - num_rows_for_head_and_tail
13123+
>= row_count_expr - num_rows_for_head_and_tail
1311213124
)
1311313125
)
1311413126

1311513127
# retrieve frame as pandas object
1311613128
new_qc = SnowflakeQueryCompiler(new_frame)
1311713129
pandas_frame = new_qc.to_pandas()
1311813130

13119-
# remove last column after first retrieving row count
13120-
row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
13121-
pandas_frame = pandas_frame.iloc[:, :-1]
13131+
if row_count_value is None:
13132+
# if we appended the row count column instead of directly doing a COUNT(*), splice it off
13133+
# remove last column after first retrieving row count
13134+
row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
13135+
pandas_frame = pandas_frame.iloc[:, :-1]
13136+
else:
13137+
row_count = row_count_value
1312213138
col_count = len(pandas_frame.columns)
1312313139

1312413140
return row_count, col_count, pandas_frame

0 commit comments

Comments
 (0)