@@ -13058,67 +13058,83 @@ def build_repr_df(
1305813058 `row_count` holds the number of rows the DataFrame has, `col_count` the number of columns the DataFrame has, and
1305913059 the pandas dataset with `num_rows` or fewer rows and `num_cols` or fewer columns.
1306013060 """
13061- # In order to issue less queries, use following trick:
13062- # 1. add the row count column holding COUNT(*) OVER () over the snowpark dataframe
13063- # 2. retrieve all columns
13064- # 3. filter on rows with recursive count
13061+ # build_repr_df needs to know the row count of the underlying data, as the displayed representation will
13062+ # include the last few rows of the frame.
13063+ #
13064+ # To maximize performance, we use two distinct code paths.
13065+ # If the underlying OrderedDataFrame is a simple projection of a table:
13066+ # 1. Perform a query to retrieve the row count. This query will be cheap because the SQL engine can
13067+ # retrieve the value from table metadata.
13068+ # 2. Directly embed the row count into the filter query as a literal.
13069+ # If the underlying data is NOT a simple projection, we opt to perform fewer queries:
13070+ # 1. add the row count column holding COUNT(*) OVER () over the snowpark dataframe
13071+ # 2. retrieve all columns
13072+ # 3. filter on rows with recursive count
1306513073
1306613074 # Previously, 2 queries were issued, and a first version replaced them with a single query and a join
1306713075 # the solution here uses a window function. This may lead to perf regressions, track these here SNOW-984177.
1306813076 # Make sure the frame we filter on carries a row position column and, on the fallback path, a row count column.
13069- self._modin_frame = (
13070- self._modin_frame.ensure_row_position_column().ensure_row_count_column()
13071- )
13072- row_count_pandas_label = (
13073- ROW_COUNT_COLUMN_LABEL
13074- if len(self._modin_frame.data_column_pandas_index_names) == 1
13075- else (ROW_COUNT_COLUMN_LABEL,)
13076- * len(self._modin_frame.data_column_pandas_index_names)
13077- )
13078- frame_with_row_count_and_position = InternalFrame.create(
13079- ordered_dataframe=self._modin_frame.ordered_dataframe,
13080- data_column_pandas_labels=self._modin_frame.data_column_pandas_labels
13081- + [row_count_pandas_label],
13082- data_column_snowflake_quoted_identifiers=self._modin_frame.data_column_snowflake_quoted_identifiers
13083- + [self._modin_frame.row_count_snowflake_quoted_identifier],
13084- data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names,
13085- index_column_pandas_labels=self._modin_frame.index_column_pandas_labels,
13086- index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers,
13087- data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types
13088- + [None],
13089- index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types,
13090- )
13077+ row_count_value = None
13078+ if self._modin_frame.ordered_dataframe.is_projection_of_table():
13079+ frame = self._modin_frame.ensure_row_position_column()
13080+ row_count_value = frame.ordered_dataframe.materialize_row_count()
13081+ row_count_expr = pandas_lit(row_count_value)
13082+ else:
13083+ frame = (
13084+ self._modin_frame.ensure_row_position_column().ensure_row_count_column()
13085+ )
13086+ row_count_pandas_label = (
13087+ ROW_COUNT_COLUMN_LABEL
13088+ if len(frame.data_column_pandas_index_names) == 1
13089+ else (ROW_COUNT_COLUMN_LABEL,)
13090+ * len(frame.data_column_pandas_index_names)
13091+ )
13092+ frame = InternalFrame.create(
13093+ ordered_dataframe=frame.ordered_dataframe,
13094+ data_column_pandas_labels=frame.data_column_pandas_labels
13095+ + [row_count_pandas_label],
13096+ data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers
13097+ + [frame.row_count_snowflake_quoted_identifier],
13098+ data_column_pandas_index_names=frame.data_column_pandas_index_names,
13099+ index_column_pandas_labels=frame.index_column_pandas_labels,
13100+ index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
13101+ data_column_types=frame.cached_data_column_snowpark_pandas_types
13102+ + [None],
13103+ index_column_types=frame.cached_index_column_snowpark_pandas_types,
13104+ )
1309113105
13092- row_count_identifier = (
13093- frame_with_row_count_and_position.row_count_snowflake_quoted_identifier
13094- )
13106+ row_count_expr = col(frame.row_count_snowflake_quoted_identifier)
1309513107 row_position_snowflake_quoted_identifier = (
13096- frame_with_row_count_and_position .row_position_snowflake_quoted_identifier
13108+ frame .row_position_snowflake_quoted_identifier
1309713109 )
1309813110
1309913111 # filter frame based on num_rows.
1310013112 # always return all columns as this may also result in a query.
1310113113 # in the future could analyze plan to see whether retrieving column count would trigger a query, if not
1310213114 # simply filter out based on static schema
1310313115 num_rows_for_head_and_tail = num_rows_to_display // 2 + 1
13104- new_frame = frame_with_row_count_and_position .filter(
13116+ new_frame = frame .filter(
1310513117 (
1310613118 col(row_position_snowflake_quoted_identifier)
1310713119 <= num_rows_for_head_and_tail
1310813120 )
1310913121 | (
1311013122 col(row_position_snowflake_quoted_identifier)
13111- >= col(row_count_identifier) - num_rows_for_head_and_tail
13123+ >= row_count_expr - num_rows_for_head_and_tail
1311213124 )
1311313125 )
1311413126
1311513127 # retrieve frame as pandas object
1311613128 new_qc = SnowflakeQueryCompiler(new_frame)
1311713129 pandas_frame = new_qc.to_pandas()
1311813130
13119- # remove last column after first retrieving row count
13120- row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
13121- pandas_frame = pandas_frame.iloc[:, :-1]
13131+ if row_count_value is None:
13132+ # The row count was not materialized up front, so it was appended as the last column of the result:
13133+ # read it from the first row (0 if no rows came back), then drop that column.
13134+ row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
13135+ pandas_frame = pandas_frame.iloc[:, :-1]
13136+ else:
13137+ row_count = row_count_value
1312213138 col_count = len(pandas_frame.columns)
1312313139
1312413140 return row_count, col_count, pandas_frame
0 commit comments