
Commit 60ee76c

SNOW-1794362: Support df.summary (#2837)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

   Fixes SNOW-1794362

2. Fill out the following pre-review checklist:

   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.
1 parent 2c71f27 commit 60ee76c

File tree

4 files changed: +314 -55 lines changed


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
 
 #### New Features
 
+- Added support for `DataFrame.summary()` to compute desired statistics of a DataFrame.
 - Added support for the following functions in `functions.py`
   - `array_reverse`
   - `divnull`
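
As context for the entry above, a minimal usage sketch of the new API, modeled on the docstring example added in this commit. It assumes an existing Snowpark `session`; the explicitly requested statistics in the second call (including the arbitrary "33%" percentile) are illustrative, not taken from the PR.

df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])

# With no arguments, summary() computes count, mean, stddev, min,
# the approximate 25%/50%/75% percentiles, and max.
df.summary().sort("SUMMARY").show()

# Statistics can also be requested explicitly, including arbitrary
# approximate percentiles written as percentages.
df.summary("count", "min", "33%", "max").sort("SUMMARY").show()

# Unrecognized statistic names raise a ValueError.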

docs/source/snowpark/dataframe.rst

Lines changed: 1 addition & 0 deletions
@@ -82,6 +82,7 @@ DataFrame
     DataFrame.show
     DataFrame.sort
     DataFrame.subtract
+    DataFrame.summary
     DataFrame.take
     DataFrame.toDF
     DataFrame.toJSON

src/snowflake/snowpark/dataframe.py

Lines changed: 146 additions & 55 deletions
@@ -4,6 +4,7 @@
 #
 
 import copy
+import functools
 import itertools
 import re
 import sys
@@ -170,6 +171,7 @@
 from snowflake.snowpark.exceptions import SnowparkDataframeException
 from snowflake.snowpark.functions import (
     abs as abs_,
+    approx_percentile,
     col,
     count,
     lit,
@@ -5046,44 +5048,15 @@ def session(self) -> "snowflake.snowpark.Session":
         """
         return self._session
 
-    @publicapi
-    def describe(
-        self, *cols: Union[str, List[str]], _emit_ast: bool = True
+    def _calculate_statistics(
+        self,
+        cols: List[str],
+        stat_func_dict: Dict[str, Callable],
     ) -> "DataFrame":
         """
-        Computes basic statistics for numeric columns, which includes
-        ``count``, ``mean``, ``stddev``, ``min``, and ``max``. If no columns
-        are provided, this function computes statistics for all numerical or
-        string columns. Non-numeric and non-string columns will be ignored
-        when calling this method.
-
-        Example::
-            >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
-            >>> desc_result = df.describe().sort("SUMMARY").show()
-            -------------------------------------------------------
-            |"SUMMARY"  |"A"                 |"B"                 |
-            -------------------------------------------------------
-            |count      |2.0                 |2.0                 |
-            |max        |3.0                 |4.0                 |
-            |mean       |2.0                 |3.0                 |
-            |min        |1.0                 |2.0                 |
-            |stddev     |1.4142135623730951  |1.4142135623730951  |
-            -------------------------------------------------------
-            <BLANKLINE>
-
-        Args:
-            cols: The names of columns whose basic statistics are computed.
+        Calculates the statistics for the specified columns.
+        This method is used for the implementation of the `describe` and `summary` methods.
         """
-        stmt = None
-        if _emit_ast:
-            stmt = self._session._ast_batch.assign()
-            expr = with_src_position(stmt.expr.sp_dataframe_describe, stmt)
-            self._set_ast_ref(expr.df)
-            col_list, expr.cols.variadic = parse_positional_args_to_list_variadic(*cols)
-            for c in col_list:
-                build_expr_from_snowpark_column_or_col_name(expr.cols.args.add(), c)
-
-        cols = parse_positional_args_to_list(*cols)
         df = self.select(cols, _emit_ast=False) if len(cols) > 0 else self
 
         # ignore non-numeric and non-string columns
@@ -5093,30 +5066,11 @@ def describe(
             if isinstance(field.datatype, (StringType, _NumericType))
         }
 
-        stat_func_dict = {
-            "count": count,
-            "mean": mean,
-            "stddev": stddev,
-            "min": min_,
-            "max": max_,
-        }
-
         # if no columns should be selected, just return stat names
         if len(numerical_string_col_type_dict) == 0:
             df = self._session.create_dataframe(
                 list(stat_func_dict.keys()), schema=["summary"], _emit_ast=False
             )
-            # We need to set the API calls for this to same API calls for describe
-            # Also add the new API calls for creating this DataFrame to the describe subcalls
-            adjust_api_subcalls(
-                df,
-                "DataFrame.describe",
-                precalls=self._plan.api_calls,
-                subcalls=df._plan.api_calls,
-            )
-
-            if _emit_ast:
-                df._ast_id = stmt.var_id.bitfield1
 
             return df
 
@@ -5128,7 +5082,7 @@ def describe(
             # for string columns, we need to convert all stats to string
             # such that they can be fitted into one column
             if isinstance(t, StringType):
-                if name in ["mean", "stddev"]:
+                if name.lower() in ["mean", "stddev"] or name.endswith("%"):
                     agg_cols.append(to_char(func(lit(None))).as_(c))
                 else:
                     agg_cols.append(to_char(func(c)))
@@ -5147,6 +5101,55 @@ def describe(
                 res_df.union(agg_stat_df, _emit_ast=False) if res_df else agg_stat_df
             )
 
+        return res_df
+
+    @publicapi
+    def describe(
+        self, *cols: Union[str, List[str]], _emit_ast: bool = True
+    ) -> "DataFrame":
+        """
+        Computes basic statistics for numeric columns, which includes
+        ``count``, ``mean``, ``stddev``, ``min``, and ``max``. If no columns
+        are provided, this function computes statistics for all numerical or
+        string columns. Non-numeric and non-string columns will be ignored
+        when calling this method.
+
+        Example::
+            >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
+            >>> desc_result = df.describe().sort("SUMMARY").show()
+            -------------------------------------------------------
+            |"SUMMARY"  |"A"                 |"B"                 |
+            -------------------------------------------------------
+            |count      |2.0                 |2.0                 |
+            |max        |3.0                 |4.0                 |
+            |mean       |2.0                 |3.0                 |
+            |min        |1.0                 |2.0                 |
+            |stddev     |1.4142135623730951  |1.4142135623730951  |
+            -------------------------------------------------------
+            <BLANKLINE>
+
+        Args:
+            cols: The names of columns whose basic statistics are computed.
+        """
+        stmt = None
+        if _emit_ast:
+            stmt = self._session._ast_batch.assign()
+            expr = with_src_position(stmt.expr.sp_dataframe_describe, stmt)
+            self._set_ast_ref(expr.df)
+            col_list, expr.cols.variadic = parse_positional_args_to_list_variadic(*cols)
+            for c in col_list:
+                build_expr_from_snowpark_column_or_col_name(expr.cols.args.add(), c)
+
+        stat_func_dict = {
+            "count": count,
+            "mean": mean,
+            "stddev": stddev,
+            "min": min_,
+            "max": max_,
+        }
+        cols = parse_positional_args_to_list(*cols)
+        res_df = self._calculate_statistics(cols, stat_func_dict)
+
         adjust_api_subcalls(
             res_df,
             "DataFrame.describe",
@@ -5159,6 +5162,94 @@ def describe(
 
         return res_df
 
+    @publicapi
+    def summary(self, *statistics: str, _emit_ast: bool = True) -> "DataFrame":
+        """
+        Computes specified statistics for all numeric and string columns.
+        Non-numeric and non-string columns will be ignored when calling this method.
+
+        Available statistics are: ``count``, ``mean``, ``stddev``, ``min``, ``max`` and
+        arbitrary approximate percentiles specified as a percentage (e.g., 75%).
+
+        If no statistics are given, this function computes ``count``, ``mean``, ``stddev``, ``min``,
+        approximate quartiles (percentiles at 25%, 50%, and 75%), and ``max``.
+
+        If no columns are provided, this function computes statistics for all numerical or
+        string columns.
+
+        Example::
+            >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
+            >>> desc_result = df.summary().sort("SUMMARY").show()
+            -------------------------------------------------------
+            |"SUMMARY"  |"A"                 |"B"                 |
+            -------------------------------------------------------
+            |25%        |1.5                 |2.5                 |
+            |50%        |2.0                 |3.0                 |
+            |75%        |2.5                 |3.5                 |
+            |count      |2.0                 |2.0                 |
+            |max        |3.0                 |4.0                 |
+            |mean       |2.0                 |3.0                 |
+            |min        |1.0                 |2.0                 |
+            |stddev     |1.4142135623730951  |1.4142135623730951  |
+            -------------------------------------------------------
+            <BLANKLINE>
+
+        Args:
+            statistics: The names of the statistics to compute.
+        """
+        # get stats that we want to calculate
+        stat_func_dict = {}
+        for s in statistics:
+            if s.lower() == "count":
+                stat_func_dict[s] = count
+            elif s.lower() == "mean":
+                stat_func_dict[s] = mean
+            elif s.lower() == "stddev":
+                stat_func_dict[s] = stddev
+            elif s.lower() == "min":
+                stat_func_dict[s] = min_
+            elif s.lower() == "max":
+                stat_func_dict[s] = max_
+            elif s.endswith("%"):
+                try:
+                    number = float(s[:-1])
+                except Exception as ex:
+                    raise ValueError(f"Unable to parse {s} as a percentile: {ex}.")
+                if number < 0 or number > 100:
+                    raise ValueError(
+                        "requirement failed: Percentiles must be in the range [0, 1]."
+                    )
+                stat_func_dict[s] = functools.partial(
+                    approx_percentile, percentile=number / 100
+                )
+            else:
+                raise ValueError(f"{s} is not a recognised statistic.")
+
+        # if stats are not specified, use the following default stats
+        if not stat_func_dict:
+            stat_func_dict = {
+                "count": count,
+                "mean": mean,
+                "stddev": stddev,
+                "min": min_,
+                "25%": lambda c: approx_percentile(c, 0.25),
+                "50%": lambda c: approx_percentile(c, 0.50),
+                "75%": lambda c: approx_percentile(c, 0.75),
+                "max": max_,
+            }
+
+        # calculate stats on all columns
+        res_df = self._calculate_statistics([], stat_func_dict)
+
+        adjust_api_subcalls(
+            res_df,
+            "DataFrame.summary",
+            precalls=self._plan.api_calls,
+            subcalls=res_df._plan.api_calls.copy(),
+        )
+
+        return res_df
+
     @df_api_usage
     @publicapi
     def rename(
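
One design note on the implementation above: each value in `stat_func_dict` must be a one-argument callable that takes a column, so user-supplied percentiles are bound to a fixed `percentile` with `functools.partial`, while the default path uses equivalent lambdas. A small standalone sketch of that mapping follows; the `parse_statistic` helper is hypothetical and not part of this PR.

import functools

from snowflake.snowpark.functions import approx_percentile, count


def parse_statistic(name: str):
    # Hypothetical helper mirroring summary()'s name-to-callable mapping.
    if name.lower() == "count":
        return count
    if name.endswith("%"):
        percentage = float(name[:-1])  # "33%" -> 33.0
        if percentage < 0 or percentage > 100:
            raise ValueError(f"{name} is outside the [0, 100] range.")
        # Bind the percentile now; the returned callable still takes only a column.
        return functools.partial(approx_percentile, percentile=percentage / 100)
    raise ValueError(f"{name} is not a recognised statistic.")


stat = parse_statistic("33%")
# stat("a") builds the same expression as approx_percentile("a", percentile=0.33)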
