diff --git a/CHANGELOG.md b/CHANGELOG.md index 32513831bb..19ac3edb5e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -54,6 +54,7 @@ - `log10` - `percentile_approx` - `unbase64` +- Added support for `seed` argument in `DataFrame.stat.sample_by`. Note that it only supports a `Table` object, and will be ignored for a `DataFrame` object. - Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`. - Added support for `DataFrameWriter.insert_into/insertInto`. This method also supports local testing mode. - Added support for multiple columns in the functions `map_cat` and `map_concat`. diff --git a/src/snowflake/snowpark/dataframe_stat_functions.py b/src/snowflake/snowpark/dataframe_stat_functions.py index f7a68fbbec..28dc588cc5 100644 --- a/src/snowflake/snowpark/dataframe_stat_functions.py +++ b/src/snowflake/snowpark/dataframe_stat_functions.py @@ -17,7 +17,7 @@ from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages from snowflake.snowpark._internal.telemetry import adjust_api_subcalls from snowflake.snowpark._internal.type_utils import ColumnOrName, LiteralType -from snowflake.snowpark._internal.utils import publicapi +from snowflake.snowpark._internal.utils import publicapi, warning from snowflake.snowpark.functions import ( _to_col_if_str, approx_percentile_accumulate, @@ -374,6 +374,7 @@ def sample_by( self, col: ColumnOrName, fractions: Dict[LiteralType, float], + seed: Optional[int] = None, _emit_ast: bool = True, ) -> "snowflake.snowpark.DataFrame": """Returns a DataFrame containing a stratified sample without replacement, based on a ``dict`` that specifies the fraction for each stratum. @@ -388,6 +389,9 @@ def sample_by( col: The name of the column that defines the strata. fractions: A ``dict`` that specifies the fraction to use for the sample for each stratum. If a stratum is not specified in the ``dict``, the method uses 0 as the fraction. 
+ seed: Specifies a seed value to make the sampling deterministic. Can be any integer between 0 and 2147483647 inclusive. + Default value is ``None``. This parameter is only supported for :class:`Table`, and it will be ignored + if it is specified for :class:`DataFrame`. """ stmt = None @@ -416,15 +420,42 @@ def sample_by( return res_df col = _to_col_if_str(col, "sample_by") - res_df = reduce( - lambda x, y: x.union_all(y, _emit_ast=False), - [ - self._dataframe.filter(col == k, _emit_ast=False).sample( - v, _emit_ast=False ) - for k, v in fractions.items() - ], - ) + if seed is not None and isinstance(self._dataframe, snowflake.snowpark.Table): + + def equal_condition_str(k: LiteralType) -> str: + return self._dataframe._session._analyzer.binary_operator_extractor( + (col == k)._expression, + df_aliased_col_name_to_real_col_name=self._dataframe._plan.df_aliased_col_name_to_real_col_name, ) + + # Similar to how `Table.sample` is implemented, because SAMPLE clause does not support subqueries, + # we just use session.sql to compile a flat query + res_df = reduce( + lambda x, y: x.union_all(y, _emit_ast=False), + [ + self._dataframe._session.sql( + f"SELECT * FROM {self._dataframe.table_name} SAMPLE ({v * 100.0}) SEED ({seed}) WHERE {equal_condition_str(k)}", + _emit_ast=False, + ) + for k, v in fractions.items() + ], + ) + else: + if seed is not None: + warning( + "stat.sample_by", + "`seed` argument is ignored on `DataFrame` object. 
Save this DataFrame to a temporary table " "to get a `Table` object and specify a seed.", + ) + res_df = reduce( + lambda x, y: x.union_all(y, _emit_ast=False), + [ + self._dataframe.filter(col == k, _emit_ast=False).sample( + v, _emit_ast=False + ) + for k, v in fractions.items() + ], + ) adjust_api_subcalls( res_df, "DataFrameStatFunctions.sample_by", diff --git a/tests/integ/scala/test_dataframe_suite.py b/tests/integ/scala/test_dataframe_suite.py index a5dfe9e0c4..08e5677e8a 100644 --- a/tests/integ/scala/test_dataframe_suite.py +++ b/tests/integ/scala/test_dataframe_suite.py @@ -2,6 +2,7 @@ # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved. # import copy +import logging import math import os @@ -837,6 +838,34 @@ def test_df_stat_sampleBy(session): assert len(sample_by_3.collect()) == 0 +@pytest.mark.skipif( + "config.getoption('local_testing_mode', default=False)", + reason="session.sql is not supported in local testing", +) +def test_df_stat_sampleBy_seed(session, caplog): + temp_table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE) + TestData.monthly_sales(session).write.save_as_table( + temp_table_name, table_type="temp", mode="overwrite" + ) + df = session.table(temp_table_name) + + # with seed, the result is deterministic and should be the same + sample_by_action = ( + lambda df: df.stat.sample_by(col("empid"), {1: 0.5, 2: 0.5}, seed=1) + .sort(df.columns) + .collect() + ) + result = sample_by_action(df) + for _ in range(3): + Utils.check_answer(sample_by_action(df), result) + + # DataFrame doesn't work with seed + caplog.clear() + with caplog.at_level(logging.WARNING): + sample_by_action(TestData.monthly_sales(session)) + assert "`seed` argument is ignored on `DataFrame` object" in caplog.text + + @pytest.mark.skipif( "config.getoption('local_testing_mode', default=False)", reason="FEAT: RelationalGroupedDataFrame.Pivot not supported",