-
Notifications
You must be signed in to change notification settings - Fork 144
SNOW-1885815: Add support for seed argument in DataFrame.stat.sample_by
#2925
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,7 +17,7 @@ | |
| from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages | ||
| from snowflake.snowpark._internal.telemetry import adjust_api_subcalls | ||
| from snowflake.snowpark._internal.type_utils import ColumnOrName, LiteralType | ||
| from snowflake.snowpark._internal.utils import publicapi | ||
| from snowflake.snowpark._internal.utils import publicapi, warning | ||
| from snowflake.snowpark.functions import ( | ||
| _to_col_if_str, | ||
| approx_percentile_accumulate, | ||
|
|
@@ -374,6 +374,7 @@ def sample_by( | |
| self, | ||
| col: ColumnOrName, | ||
| fractions: Dict[LiteralType, float], | ||
| seed: Optional[int] = None, | ||
| _emit_ast: bool = True, | ||
| ) -> "snowflake.snowpark.DataFrame": | ||
| """Returns a DataFrame containing a stratified sample without replacement, based on a ``dict`` that specifies the fraction for each stratum. | ||
|
|
@@ -388,6 +389,9 @@ def sample_by( | |
| col: The name of the column that defines the strata. | ||
| fractions: A ``dict`` that specifies the fraction to use for the sample for each stratum. | ||
| If a stratum is not specified in the ``dict``, the method uses 0 as the fraction. | ||
| seed: Specifies a seed value to make the sampling deterministic. Can be any integer between 0 and 2147483647 inclusive. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what would happen if seed is not in the accepted range?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It will raise a SQL error. I think we don't need to do it right now, as |
||
| Default value is ``None``. This parameter is only supported for :class:`Table`, and it will be ignored | ||
| if it is specified for :class`DataFrame`. | ||
|
Comment on lines
+393
to
+394
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. curious to know why we choose to ignore instead of failing?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For SAS team to work around by creating a temp table first. See https://snowflakecomputing.atlassian.net/browse/SNOW-1894684 |
||
| """ | ||
|
|
||
| stmt = None | ||
|
|
@@ -416,15 +420,42 @@ def sample_by( | |
| return res_df | ||
|
|
||
| col = _to_col_if_str(col, "sample_by") | ||
| res_df = reduce( | ||
| lambda x, y: x.union_all(y, _emit_ast=False), | ||
| [ | ||
| self._dataframe.filter(col == k, _emit_ast=False).sample( | ||
| v, _emit_ast=False | ||
| if seed is not None and isinstance(self._dataframe, snowflake.snowpark.Table): | ||
|
|
||
| def equal_condition_str(k: LiteralType) -> str: | ||
| return self._dataframe._session._analyzer.binary_operator_extractor( | ||
| (col == k)._expression, | ||
| df_aliased_col_name_to_real_col_name=self._dataframe._plan.df_aliased_col_name_to_real_col_name, | ||
| ) | ||
| for k, v in fractions.items() | ||
| ], | ||
| ) | ||
|
|
||
| # Similar to how `Table.sample` is implemented, because SAMPLE clause does not support subqueries, | ||
| # we just use session.sql to compile a flat query | ||
| res_df = reduce( | ||
| lambda x, y: x.union_all(y, _emit_ast=False), | ||
| [ | ||
| self._dataframe._session.sql( | ||
| f"SELECT * FROM {self._dataframe.table_name} SAMPLE ({v * 100.0}) SEED ({seed}) WHERE {equal_condition_str(k)}", | ||
| _emit_ast=False, | ||
| ) | ||
| for k, v in fractions.items() | ||
| ], | ||
| ) | ||
| else: | ||
| if seed is not None: | ||
| warning( | ||
| "stat.sample_by", | ||
| "`seed` argument is ignored on `DataFrame` object. Save this DataFrame to a temporary table " | ||
| "to get a `Table` object and specify a seed.", | ||
| ) | ||
| res_df = reduce( | ||
| lambda x, y: x.union_all(y, _emit_ast=False), | ||
| [ | ||
| self._dataframe.filter(col == k, _emit_ast=False).sample( | ||
| v, _emit_ast=False | ||
| ) | ||
| for k, v in fractions.items() | ||
| ], | ||
| ) | ||
| adjust_api_subcalls( | ||
| res_df, | ||
| "DataFrameStatFunctions.sample_by", | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about logging a warning when the DataFrame isn't a Table?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yea it's already added