Skip to content

Commit dfbb0eb

Browse files
authored
SNOW-1885815: Add support for seed argument in DataFrame.stat.sample_by (#2925)
1 parent e8a6756 commit dfbb0eb

File tree

3 files changed

+70
-9
lines changed

3 files changed

+70
-9
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
- `log10`
5555
- `percentile_approx`
5656
- `unbase64`
57+
- Added support for the `seed` argument in `DataFrame.stat.sample_by`. Note that `seed` only takes effect when the method is called on a `Table` object; for a `DataFrame` object it is ignored and a warning is logged.
5758
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.
5859
- Added support for `DataFrameWriter.insert_into/insertInto`. This method also supports local testing mode.
5960
- Added support for `DataFrame.create_temp_view` to create a temporary view. It will fail if the view already exists.

src/snowflake/snowpark/dataframe_stat_functions.py

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
1818
from snowflake.snowpark._internal.telemetry import adjust_api_subcalls
1919
from snowflake.snowpark._internal.type_utils import ColumnOrName, LiteralType
20-
from snowflake.snowpark._internal.utils import publicapi
20+
from snowflake.snowpark._internal.utils import publicapi, warning
2121
from snowflake.snowpark.functions import (
2222
_to_col_if_str,
2323
approx_percentile_accumulate,
@@ -374,6 +374,7 @@ def sample_by(
374374
self,
375375
col: ColumnOrName,
376376
fractions: Dict[LiteralType, float],
377+
seed: Optional[int] = None,
377378
_emit_ast: bool = True,
378379
) -> "snowflake.snowpark.DataFrame":
379380
"""Returns a DataFrame containing a stratified sample without replacement, based on a ``dict`` that specifies the fraction for each stratum.
@@ -388,6 +389,9 @@ def sample_by(
388389
col: The name of the column that defines the strata.
389390
fractions: A ``dict`` that specifies the fraction to use for the sample for each stratum.
390391
If a stratum is not specified in the ``dict``, the method uses 0 as the fraction.
392+
seed: Specifies a seed value to make the sampling deterministic. Can be any integer between 0 and 2147483647 inclusive.
393+
Default value is ``None``. This parameter is only supported for :class:`Table`, and it will be ignored
394+
if it is specified for :class:`DataFrame`.
391395
"""
392396

393397
stmt = None
@@ -416,15 +420,42 @@ def sample_by(
416420
return res_df
417421

418422
col = _to_col_if_str(col, "sample_by")
419-
res_df = reduce(
420-
lambda x, y: x.union_all(y, _emit_ast=False),
421-
[
422-
self._dataframe.filter(col == k, _emit_ast=False).sample(
423-
v, _emit_ast=False
423+
if seed is not None and isinstance(self._dataframe, snowflake.snowpark.Table):
424+
425+
def equal_condition_str(k: LiteralType) -> str:
426+
return self._dataframe._session._analyzer.binary_operator_extractor(
427+
(col == k)._expression,
428+
df_aliased_col_name_to_real_col_name=self._dataframe._plan.df_aliased_col_name_to_real_col_name,
424429
)
425-
for k, v in fractions.items()
426-
],
427-
)
430+
431+
# Similar to how `Table.sample` is implemented, because SAMPLE clause does not support subqueries,
432+
# we just use session.sql to compile a flat query
433+
res_df = reduce(
434+
lambda x, y: x.union_all(y, _emit_ast=False),
435+
[
436+
self._dataframe._session.sql(
437+
f"SELECT * FROM {self._dataframe.table_name} SAMPLE ({v * 100.0}) SEED ({seed}) WHERE {equal_condition_str(k)}",
438+
_emit_ast=False,
439+
)
440+
for k, v in fractions.items()
441+
],
442+
)
443+
else:
444+
if seed is not None:
445+
warning(
446+
"stat.sample_by",
447+
"`seed` argument is ignored on `DataFrame` object. Save this DataFrame to a temporary table "
448+
"to get a `Table` object and specify a seed.",
449+
)
450+
res_df = reduce(
451+
lambda x, y: x.union_all(y, _emit_ast=False),
452+
[
453+
self._dataframe.filter(col == k, _emit_ast=False).sample(
454+
v, _emit_ast=False
455+
)
456+
for k, v in fractions.items()
457+
],
458+
)
428459
adjust_api_subcalls(
429460
res_df,
430461
"DataFrameStatFunctions.sample_by",

tests/integ/scala/test_dataframe_suite.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
33
#
44

5+
import logging
56
import copy
67
import math
78
import os
@@ -837,6 +838,34 @@ def test_df_stat_sampleBy(session):
837838
assert len(sample_by_3.collect()) == 0
838839

839840

841+
@pytest.mark.skipif(
842+
"config.getoption('local_testing_mode', default=False)",
843+
reason="session.sql is not supported in local testing",
844+
)
845+
def test_df_stat_sampleBy_seed(session, caplog):
846+
temp_table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
847+
TestData.monthly_sales(session).write.save_as_table(
848+
temp_table_name, table_type="temp", mode="overwrite"
849+
)
850+
df = session.table(temp_table_name)
851+
852+
# with seed, the result is deterministic and should be the same
853+
sample_by_action = (
854+
lambda df: df.stat.sample_by(col("empid"), {1: 0.5, 2: 0.5}, seed=1)
855+
.sort(df.columns)
856+
.collect()
857+
)
858+
result = sample_by_action(df)
859+
for _ in range(3):
860+
Utils.check_answer(sample_by_action(df), result)
861+
862+
# on a plain DataFrame the seed is ignored and a warning is logged
863+
caplog.clear()
864+
with caplog.at_level(logging.WARNING):
865+
sample_by_action(TestData.monthly_sales(session))
866+
assert "`seed` argument is ignored on `DataFrame` object" in caplog.text
867+
868+
840869
@pytest.mark.skipif(
841870
"config.getoption('local_testing_mode', default=False)",
842871
reason="FEAT: RelationalGroupedDataFrame.Pivot not supported",

0 commit comments

Comments
 (0)