44#
55
66import copy
7+ import functools
78import itertools
89import re
910import sys
170171from snowflake .snowpark .exceptions import SnowparkDataframeException
171172from snowflake .snowpark .functions import (
172173 abs as abs_ ,
174+ approx_percentile ,
173175 col ,
174176 count ,
175177 lit ,
@@ -5046,44 +5048,15 @@ def session(self) -> "snowflake.snowpark.Session":
50465048 """
50475049 return self ._session
50485050
5049- @publicapi
5050- def describe (
5051- self , * cols : Union [str , List [str ]], _emit_ast : bool = True
5051+ def _calculate_statistics (
5052+ self ,
5053+ cols : List [str ],
5054+ stat_func_dict : Dict [str , Callable ],
50525055 ) -> "DataFrame" :
50535056 """
5054- Computes basic statistics for numeric columns, which includes
5055- ``count``, ``mean``, ``stddev``, ``min``, and ``max``. If no columns
5056- are provided, this function computes statistics for all numerical or
5057- string columns. Non-numeric and non-string columns will be ignored
5058- when calling this method.
5059-
5060- Example::
5061- >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
5062- >>> desc_result = df.describe().sort("SUMMARY").show()
5063- -------------------------------------------------------
5064- |"SUMMARY" |"A" |"B" |
5065- -------------------------------------------------------
5066- |count |2.0 |2.0 |
5067- |max |3.0 |4.0 |
5068- |mean |2.0 |3.0 |
5069- |min |1.0 |2.0 |
5070- |stddev |1.4142135623730951 |1.4142135623730951 |
5071- -------------------------------------------------------
5072- <BLANKLINE>
5073-
5074- Args:
5075- cols: The names of columns whose basic statistics are computed.
5057+ Calculates the statistics for the specified columns.
5058+ This method is used for the implementation of the `describe` and `summary` method.
50765059 """
5077- stmt = None
5078- if _emit_ast :
5079- stmt = self ._session ._ast_batch .assign ()
5080- expr = with_src_position (stmt .expr .sp_dataframe_describe , stmt )
5081- self ._set_ast_ref (expr .df )
5082- col_list , expr .cols .variadic = parse_positional_args_to_list_variadic (* cols )
5083- for c in col_list :
5084- build_expr_from_snowpark_column_or_col_name (expr .cols .args .add (), c )
5085-
5086- cols = parse_positional_args_to_list (* cols )
50875060 df = self .select (cols , _emit_ast = False ) if len (cols ) > 0 else self
50885061
50895062 # ignore non-numeric and non-string columns
@@ -5093,30 +5066,11 @@ def describe(
50935066 if isinstance (field .datatype , (StringType , _NumericType ))
50945067 }
50955068
5096- stat_func_dict = {
5097- "count" : count ,
5098- "mean" : mean ,
5099- "stddev" : stddev ,
5100- "min" : min_ ,
5101- "max" : max_ ,
5102- }
5103-
51045069 # if no columns should be selected, just return stat names
51055070 if len (numerical_string_col_type_dict ) == 0 :
51065071 df = self ._session .create_dataframe (
51075072 list (stat_func_dict .keys ()), schema = ["summary" ], _emit_ast = False
51085073 )
5109- # We need to set the API calls for this to same API calls for describe
5110- # Also add the new API calls for creating this DataFrame to the describe subcalls
5111- adjust_api_subcalls (
5112- df ,
5113- "DataFrame.describe" ,
5114- precalls = self ._plan .api_calls ,
5115- subcalls = df ._plan .api_calls ,
5116- )
5117-
5118- if _emit_ast :
5119- df ._ast_id = stmt .var_id .bitfield1
51205074
51215075 return df
51225076
@@ -5128,7 +5082,7 @@ def describe(
51285082 # for string columns, we need to convert all stats to string
51295083 # such that they can be fitted into one column
51305084 if isinstance (t , StringType ):
5131- if name in ["mean" , "stddev" ]:
5085+ if name . lower () in ["mean" , "stddev" ] or name . endswith ( "%" ) :
51325086 agg_cols .append (to_char (func (lit (None ))).as_ (c ))
51335087 else :
51345088 agg_cols .append (to_char (func (c )))
@@ -5147,6 +5101,55 @@ def describe(
51475101 res_df .union (agg_stat_df , _emit_ast = False ) if res_df else agg_stat_df
51485102 )
51495103
5104+ return res_df
5105+
5106+ @publicapi
5107+ def describe (
5108+ self , * cols : Union [str , List [str ]], _emit_ast : bool = True
5109+ ) -> "DataFrame" :
5110+ """
5111+ Computes basic statistics for numeric columns, which includes
5112+ ``count``, ``mean``, ``stddev``, ``min``, and ``max``. If no columns
5113+ are provided, this function computes statistics for all numerical or
5114+ string columns. Non-numeric and non-string columns will be ignored
5115+ when calling this method.
5116+
5117+ Example::
5118+ >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
5119+ >>> desc_result = df.describe().sort("SUMMARY").show()
5120+ -------------------------------------------------------
5121+ |"SUMMARY" |"A" |"B" |
5122+ -------------------------------------------------------
5123+ |count |2.0 |2.0 |
5124+ |max |3.0 |4.0 |
5125+ |mean |2.0 |3.0 |
5126+ |min |1.0 |2.0 |
5127+ |stddev |1.4142135623730951 |1.4142135623730951 |
5128+ -------------------------------------------------------
5129+ <BLANKLINE>
5130+
5131+ Args:
5132+ cols: The names of columns whose basic statistics are computed.
5133+ """
5134+ stmt = None
5135+ if _emit_ast :
5136+ stmt = self ._session ._ast_batch .assign ()
5137+ expr = with_src_position (stmt .expr .sp_dataframe_describe , stmt )
5138+ self ._set_ast_ref (expr .df )
5139+ col_list , expr .cols .variadic = parse_positional_args_to_list_variadic (* cols )
5140+ for c in col_list :
5141+ build_expr_from_snowpark_column_or_col_name (expr .cols .args .add (), c )
5142+
5143+ stat_func_dict = {
5144+ "count" : count ,
5145+ "mean" : mean ,
5146+ "stddev" : stddev ,
5147+ "min" : min_ ,
5148+ "max" : max_ ,
5149+ }
5150+ cols = parse_positional_args_to_list (* cols )
5151+ res_df = self ._calculate_statistics (cols , stat_func_dict )
5152+
51505153 adjust_api_subcalls (
51515154 res_df ,
51525155 "DataFrame.describe" ,
@@ -5159,6 +5162,94 @@ def describe(
51595162
51605163 return res_df
51615164
5165+ @publicapi
5166+ def summary (self , * statistics : str , _emit_ast : bool = True ) -> "DataFrame" :
5167+ """
5168+ Computes specified statistics for all numeric and string columns.
5169+ Non-numeric and non-string columns will be ignored when calling this method.
5170+
5171+ Available statistics are: ``count``, ``mean``, ``stddev``, ``min``, ``max`` and
5172+ arbitrary approximate percentiles specified as a percentage (e.g., 75%).
5173+
5174+ If no statistics are given, this function computes ``count``, ``mean``, ``stddev``, ``min``,
5175+ approximate quartiles (percentiles at 25%, 50%, and 75%), and ``max``.
5176+
5177+ If no columns are provided, this function computes statistics for all numerical or
5178+ string columns.
5179+
5180+ Example::
5181+ >>> df = session.create_dataframe([[1, 2], [3, 4]], schema=["a", "b"])
5182+ >>> desc_result = df.summary().sort("SUMMARY").show()
5183+ -------------------------------------------------------
5184+ |"SUMMARY" |"A" |"B" |
5185+ -------------------------------------------------------
5186+ |25% |1.5 |2.5 |
5187+ |50% |2.0 |3.0 |
5188+ |75% |2.5 |3.5 |
5189+ |count |2.0 |2.0 |
5190+ |max |3.0 |4.0 |
5191+ |mean |2.0 |3.0 |
5192+ |min |1.0 |2.0 |
5193+ |stddev |1.4142135623730951 |1.4142135623730951 |
5194+ -------------------------------------------------------
5195+ <BLANKLINE>
5196+
5197+ Args:
5198+ statistics: The names of columns whose basic statistics are computed.
5199+ """
5200+ # get stats that we want to calculate
5201+ stat_func_dict = {}
5202+ for s in statistics :
5203+ if s .lower () == "count" :
5204+ stat_func_dict [s ] = count
5205+ elif s .lower () == "mean" :
5206+ stat_func_dict [s ] = mean
5207+ elif s .lower () == "stddev" :
5208+ stat_func_dict [s ] = stddev
5209+ elif s .lower () == "min" :
5210+ stat_func_dict [s ] = min_
5211+ elif s .lower () == "max" :
5212+ stat_func_dict [s ] = max_
5213+ elif s .endswith ("%" ):
5214+ try :
5215+ number = float (s [:- 1 ])
5216+ except Exception as ex :
5217+ raise ValueError (f"Unable to parse { s } as a percentile: { ex } ." )
5218+ if number < 0 or number > 100 :
5219+ raise ValueError (
5220+ "requirement failed: Percentiles must be in the range [0, 1]."
5221+ )
5222+ stat_func_dict [s ] = functools .partial (
5223+ approx_percentile , percentile = number / 100
5224+ )
5225+ else :
5226+ raise ValueError (f"{ s } is not a recognised statistic." )
5227+
5228+ # if stats are not specified, use the following default stats
5229+ if not stat_func_dict :
5230+ stat_func_dict = {
5231+ "count" : count ,
5232+ "mean" : mean ,
5233+ "stddev" : stddev ,
5234+ "min" : min_ ,
5235+ "25%" : lambda c : approx_percentile (c , 0.25 ),
5236+ "50%" : lambda c : approx_percentile (c , 0.50 ),
5237+ "75%" : lambda c : approx_percentile (c , 0.75 ),
5238+ "max" : max_ ,
5239+ }
5240+
5241+ # calculate stats on all columns
5242+ res_df = self ._calculate_statistics ([], stat_func_dict )
5243+
5244+ adjust_api_subcalls (
5245+ res_df ,
5246+ "DataFrame.summary" ,
5247+ precalls = self ._plan .api_calls ,
5248+ subcalls = res_df ._plan .api_calls .copy (),
5249+ )
5250+
5251+ return res_df
5252+
51625253 @df_api_usage
51635254 @publicapi
51645255 def rename (
0 commit comments