Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -425,13 +425,25 @@ def ensure_row_count_column(self) -> "OrderedDataFrame":
*self.projected_column_snowflake_quoted_identifiers,
count("*").over().as_(row_count_snowflake_quoted_identifier),
)

# inplace update so dataframe_ref can be shared. Note that we keep
# the original ordering columns.
ordered_dataframe.row_count_snowflake_quoted_identifier = (
row_count_snowflake_quoted_identifier
)
return ordered_dataframe

def materialize_row_count(self) -> int:
    """
    Perform a query to retrieve the row count of this OrderedDataFrame.

    Use this function in place of ensure_row_count_column() in scenarios where the extra
    query is acceptable, and the embedded `COUNT(*) OVER()` window operation would be too
    expensive. Performing a naked `COUNT(*)` and avoiding a potential window function or
    cross join is more performant in these scenarios.

    Returns:
        int: the number of rows in this OrderedDataFrame.
    """
    # Eagerly issues a `SELECT COUNT(*)` against the underlying Snowpark DataFrame's
    # current plan; unlike ensure_row_count_column(), no column is added to the frame.
    return self._dataframe_ref.snowpark_dataframe.count()

def generate_snowflake_quoted_identifiers(
self,
*,
Expand Down Expand Up @@ -2020,3 +2032,30 @@ def sample(self, n: Optional[int], frac: Optional[float]) -> "OrderedDataFrame":
ordering_columns=self.ordering_columns,
)
)

def is_projection_of_table(self) -> bool:
    """
    Return whether or not the current OrderedDataFrame is simply a projection of a table.

    Returns:
        bool
            True if the current OrderedDataFrame is simply a projection of a table. False if it
            represents a more complex operation.
    """
    # If we have only performed projections since creating this DataFrame, it will only contain
    # 1 API call in the plan - either `Session.sql` for DataFrames based off of I/O operations
    # e.g. `read_snowflake` or `read_csv`, or `Session.create_dataframe` for DataFrames created
    # out of Python objects.
    # We must also ensure that the underlying compiled query plan is only a single query --
    # for example, a simple select on pd.DataFrame([1] * 2000) would result in CREATE TEMP TABLE
    # + batch INSERT + DROP TABLE queries, which introduce non-trivial overhead.
    # NOTE(review): nested projections (e.g. repeated column selections such as
    # df[df.columns[:5:-1]].select_dtypes()) were reported to leave the api_calls format
    # unchanged, so this single-level check should also cover them -- confirm in a followup.
    snowpark_df = self._dataframe_ref.snowpark_dataframe
    snowpark_plan = snowpark_df._plan
    return (
        len(snowpark_plan.api_calls) == 1
        and len(snowpark_plan.queries) == 1
        and any(
            accepted_api in snowpark_plan.api_calls[0]["name"]
            for accepted_api in ["Session.sql", "Session.create_dataframe"]
        )
    )
Original file line number Diff line number Diff line change
Expand Up @@ -13050,67 +13050,81 @@ def build_repr_df(
`row_count` holds the number of rows the DataFrame has, `col_count` the number of columns the DataFrame has, and
the pandas dataset with `num_rows` or fewer rows and `num_cols` or fewer columns.
"""
# In order to issue fewer queries, use the following trick:
# 1. add the row count column holding COUNT(*) OVER () over the snowpark dataframe
# 2. retrieve all columns
# 3. filter on rows with recursive count
# build_repr_df needs to know the row count of the underlying data, as the displayed representation will
# include the last few rows of the frame.
#
# To maximize performance, we use two distinct code paths.
# If the underlying OrderedDataFrame is a simple projection of a table:
# 1. Perform a query to retrieve the row count. This query will be cheap because the SQL engine can
# retrieve the value from table metadata.
# 2. Directly embed the row count into the filter query as a literal.
# If the underlying data is NOT a simple projection, we opt to perform fewer queries:
# 1. add the row count column holding COUNT(*) OVER () over the snowpark dataframe
# 2. retrieve all columns
# 3. filter on rows with recursive count

# Previously, 2 queries were issued, and a first version replaced them with a single query and a join;
# the solution here uses a window function. This may lead to perf regressions; track these in SNOW-984177.
# Ensure that our reference to self._modin_frame is updated with cached row count and position.
self._modin_frame = (
self._modin_frame.ensure_row_position_column().ensure_row_count_column()
)
row_count_pandas_label = (
ROW_COUNT_COLUMN_LABEL
if len(self._modin_frame.data_column_pandas_index_names) == 1
else (ROW_COUNT_COLUMN_LABEL,)
* len(self._modin_frame.data_column_pandas_index_names)
)
frame_with_row_count_and_position = InternalFrame.create(
ordered_dataframe=self._modin_frame.ordered_dataframe,
data_column_pandas_labels=self._modin_frame.data_column_pandas_labels
+ [row_count_pandas_label],
data_column_snowflake_quoted_identifiers=self._modin_frame.data_column_snowflake_quoted_identifiers
+ [self._modin_frame.row_count_snowflake_quoted_identifier],
data_column_pandas_index_names=self._modin_frame.data_column_pandas_index_names,
index_column_pandas_labels=self._modin_frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=self._modin_frame.index_column_snowflake_quoted_identifiers,
data_column_types=self._modin_frame.cached_data_column_snowpark_pandas_types
+ [None],
index_column_types=self._modin_frame.cached_index_column_snowpark_pandas_types,
)
row_count_value = None
frame = self._modin_frame.ensure_row_position_column()
if self._modin_frame.ordered_dataframe.is_projection_of_table():
row_count_value = frame.ordered_dataframe.materialize_row_count()
row_count_expr = pandas_lit(row_count_value)
else:
frame = frame.ensure_row_count_column()
row_count_pandas_label = (
ROW_COUNT_COLUMN_LABEL
if len(frame.data_column_pandas_index_names) == 1
else (ROW_COUNT_COLUMN_LABEL,)
* len(frame.data_column_pandas_index_names)
)
frame = InternalFrame.create(
ordered_dataframe=frame.ordered_dataframe,
data_column_pandas_labels=frame.data_column_pandas_labels
+ [row_count_pandas_label],
data_column_snowflake_quoted_identifiers=frame.data_column_snowflake_quoted_identifiers
+ [frame.row_count_snowflake_quoted_identifier],
data_column_pandas_index_names=frame.data_column_pandas_index_names,
index_column_pandas_labels=frame.index_column_pandas_labels,
index_column_snowflake_quoted_identifiers=frame.index_column_snowflake_quoted_identifiers,
data_column_types=frame.cached_data_column_snowpark_pandas_types
+ [None],
index_column_types=frame.cached_index_column_snowpark_pandas_types,
)

row_count_identifier = (
frame_with_row_count_and_position.row_count_snowflake_quoted_identifier
)
row_count_expr = col(frame.row_count_snowflake_quoted_identifier)
row_position_snowflake_quoted_identifier = (
frame_with_row_count_and_position.row_position_snowflake_quoted_identifier
frame.row_position_snowflake_quoted_identifier
)

# filter frame based on num_rows.
# always return all columns as this may also result in a query.
# in the future could analyze plan to see whether retrieving column count would trigger a query, if not
# simply filter out based on static schema
num_rows_for_head_and_tail = num_rows_to_display // 2 + 1
new_frame = frame_with_row_count_and_position.filter(
new_frame = frame.filter(
(
col(row_position_snowflake_quoted_identifier)
<= num_rows_for_head_and_tail
)
| (
col(row_position_snowflake_quoted_identifier)
>= col(row_count_identifier) - num_rows_for_head_and_tail
>= row_count_expr - num_rows_for_head_and_tail
)
)

# retrieve frame as pandas object
new_qc = SnowflakeQueryCompiler(new_frame)
pandas_frame = new_qc.to_pandas()

# remove last column after first retrieving row count
row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
pandas_frame = pandas_frame.iloc[:, :-1]
if row_count_value is None:
# if we appended the row count column instead of directly doing a COUNT(*), splice it off
# remove last column after first retrieving row count
row_count = 0 if 0 == len(pandas_frame) else pandas_frame.iat[0, -1]
pandas_frame = pandas_frame.iloc[:, :-1]
else:
row_count = row_count_value
col_count = len(pandas_frame.columns)

return row_count, col_count, pandas_frame
Expand Down
58 changes: 41 additions & 17 deletions tests/integ/modin/frame/test_repr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@
from tests.integ.utils.sql_counter import SqlCounter, sql_count_checker

# expected_query_count is for the test_repr_html parameterized SqlCounter test.
# If the input data would be retrieved by a simple select, an additional query + select is
# incurred to eagerly retrieve the row count (larger data may avoid this because the CREATE TEMP
# TABLE + batch INSERT + DROP TABLE queries are likely more expensive than the COUNT(*) OVER() window
# function)
_DATAFRAMES_TO_TEST = [
(
native_pd.DataFrame(
Expand All @@ -29,60 +33,74 @@
],
}
),
1,
2,
2,
),
(
native_pd.DataFrame([1, 2], index=[pd.Timedelta(1), pd.Timedelta(-1)]),
1,
2,
2,
),
(
IRIS_DF,
4,
1,
),
(
native_pd.DataFrame(),
1,
2,
2,
),
(
native_pd.DataFrame(
{"A": list(range(10000)), "B": np.random.normal(size=10000)}
),
4,
1,
),
(
native_pd.DataFrame(columns=["A", "B", "C", "D", "C", "B", "A"]),
1,
2,
2,
),
# one large dataframe to test many columns
(
native_pd.DataFrame(columns=[f"x{i}" for i in range(300)]),
1,
2,
2,
),
# one large dataframe to test both columns/rows
(
native_pd.DataFrame(
data=np.zeros(shape=(300, 300)), columns=[f"x{i}" for i in range(300)]
),
4,
1,
),
]


@pytest.mark.parametrize(
    "native_df, expected_query_count, expected_select_count", _DATAFRAMES_TO_TEST
)
def test_repr(native_df, expected_query_count, expected_select_count):
    """repr() of a Snowpark pandas DataFrame matches native pandas for the same data."""
    snow_df = pd.DataFrame(native_df)

    native_str = repr(native_df)
    # only measure select statements here, creation of dfs may yield a couple
    # CREATE TEMPORARY TABLE/INSERT INTO queries
    with SqlCounter(
        query_count=expected_query_count, select_count=expected_select_count
    ):
        snow_str = repr(snow_df)

    assert native_str == snow_str


@pytest.mark.parametrize("native_df, expected_query_count", _DATAFRAMES_TO_TEST)
def test_repr_html(native_df, expected_query_count):
@pytest.mark.parametrize(
"native_df, expected_query_count, expected_select_count", _DATAFRAMES_TO_TEST
)
def test_repr_html(native_df, expected_query_count, expected_select_count):

# TODO: SNOW-916596 Test this with Jupyter notebooks.
# joins due to temp table creation
Expand All @@ -97,7 +115,9 @@ def test_repr_html(native_df, expected_query_count):
native_html = native_df._repr_html_()

# 10 of these are related to stored procs, inserts, alter session query tag.
with SqlCounter(query_count=expected_query_count, select_count=1):
with SqlCounter(
query_count=expected_query_count, select_count=expected_select_count
):
snow_html = snow_df._repr_html_()

assert native_html == snow_html
Expand Down Expand Up @@ -133,19 +153,23 @@ def queries(self) -> list[QueryRecord]:
return [query.sql_text for query in self._queries]


@pytest.mark.parametrize("native_df, expected_query_count", _DATAFRAMES_TO_TEST)
def test_repr_and_repr_html_issue_same_query(native_df, expected_query_count):
"""This test ensures that the same query is issued for both `repr` and `repr_html`
def test_repr_and_repr_html_issue_same_query():
"""
This test ensures that the same query is issued for both `repr` and `repr_html`
in order to take advantage of Snowflake server side caching when both are called back
to back (as in the case with displaying a DataFrame in a Jupyter notebook)."""
to back (as in the case with displaying a DataFrame in a Jupyter notebook).

If the input frame was a simple projection that results in an additional COUNT() query,
it is not captured by the ReprQueryListener.
"""

native_df = native_pd.DataFrame([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
snow_df = pd.DataFrame(native_df)

with ReprQueryListener(pd.session) as listener:
with SqlCounter(query_count=1, select_count=1):
with SqlCounter(query_count=2, select_count=2):
repr_str = repr(snow_df)
with SqlCounter(query_count=1, select_count=1):
with SqlCounter(query_count=2, select_count=2):
repr_html = snow_df._repr_html_()

assert repr_str == repr(native_df)
Expand Down
Loading