Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
- `try_to_binary`
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.
- Added support for `DataFrameWriter.insert_into/insertInto`. This method also supports local testing mode.
- Added support for `seed` argument in `DataFrame.stat.sample_by`. Note that the seed only takes effect for a `Table` object; it is ignored (with a warning) when called on a plain `DataFrame` object.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about logging a warning when the DataFrame isn't a Table?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea it's already added


#### Experimental Features

Expand Down
49 changes: 40 additions & 9 deletions src/snowflake/snowpark/dataframe_stat_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from snowflake.snowpark._internal.error_message import SnowparkClientExceptionMessages
from snowflake.snowpark._internal.telemetry import adjust_api_subcalls
from snowflake.snowpark._internal.type_utils import ColumnOrName, LiteralType
from snowflake.snowpark._internal.utils import publicapi
from snowflake.snowpark._internal.utils import publicapi, warning
from snowflake.snowpark.functions import (
_to_col_if_str,
approx_percentile_accumulate,
Expand Down Expand Up @@ -374,6 +374,7 @@ def sample_by(
self,
col: ColumnOrName,
fractions: Dict[LiteralType, float],
seed: Optional[int] = None,
_emit_ast: bool = True,
) -> "snowflake.snowpark.DataFrame":
"""Returns a DataFrame containing a stratified sample without replacement, based on a ``dict`` that specifies the fraction for each stratum.
Expand All @@ -388,6 +389,9 @@ def sample_by(
col: The name of the column that defines the strata.
fractions: A ``dict`` that specifies the fraction to use for the sample for each stratum.
If a stratum is not specified in the ``dict``, the method uses 0 as the fraction.
seed: Specifies a seed value to make the sampling deterministic. Can be any integer between 0 and 2147483647 inclusive.
Copy link
Contributor

@sfc-gh-aling sfc-gh-aling Jan 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what would happen if seed is not in the accepted range?
a quick follow-up is do we need client validation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will raise a SQL error. I think we don't need to do it right now, as DataFrame.sample also doesn't do it.

Default value is ``None``. This parameter is only supported for :class:`Table`, and it will be ignored
if it is specified for :class:`DataFrame`.
Comment on lines +393 to +394
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious to know why we choose to ignore instead of failing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For SAS team to work around by creating a temp table first. See https://snowflakecomputing.atlassian.net/browse/SNOW-1894684

"""

stmt = None
Expand Down Expand Up @@ -416,15 +420,42 @@ def sample_by(
return res_df

col = _to_col_if_str(col, "sample_by")
res_df = reduce(
lambda x, y: x.union_all(y, _emit_ast=False),
[
self._dataframe.filter(col == k, _emit_ast=False).sample(
v, _emit_ast=False
if seed is not None and isinstance(self._dataframe, snowflake.snowpark.Table):

def equal_condition_str(k: LiteralType) -> str:
return self._dataframe._session._analyzer.binary_operator_extractor(
(col == k)._expression,
df_aliased_col_name_to_real_col_name=self._dataframe._plan.df_aliased_col_name_to_real_col_name,
)
for k, v in fractions.items()
],
)

# Similar to how `Table.sample` is implemented, because SAMPLE clause does not support subqueries,
# we just use session.sql to compile a flat query
res_df = reduce(
lambda x, y: x.union_all(y, _emit_ast=False),
[
self._dataframe._session.sql(
f"SELECT * FROM {self._dataframe.table_name} SAMPLE ({v * 100.0}) SEED ({seed}) WHERE {equal_condition_str(k)}",
_emit_ast=False,
)
for k, v in fractions.items()
],
)
else:
if seed is not None:
warning(
"stat.sample_by",
"`seed` argument is ignored on `DataFrame` object. Save this DataFrame to a temporary table "
"to get a `Table` object and specify a seed.",
)
res_df = reduce(
lambda x, y: x.union_all(y, _emit_ast=False),
[
self._dataframe.filter(col == k, _emit_ast=False).sample(
v, _emit_ast=False
)
for k, v in fractions.items()
],
)
adjust_api_subcalls(
res_df,
"DataFrameStatFunctions.sample_by",
Expand Down
29 changes: 29 additions & 0 deletions tests/integ/scala/test_dataframe_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
#

import logging
import copy
import math
import os
Expand Down Expand Up @@ -837,6 +838,34 @@ def test_df_stat_sampleBy(session):
assert len(sample_by_3.collect()) == 0


@pytest.mark.skipif(
    "config.getoption('local_testing_mode', default=False)",
    reason="session.sql is not supported in local testing",
)
def test_df_stat_sampleBy_seed(session, caplog):
    """sample_by with a seed is deterministic on a Table and warns on a DataFrame."""
    table_name = Utils.random_name_for_temp_object(TempObjectType.TABLE)
    TestData.monthly_sales(session).write.save_as_table(
        table_name, table_type="temp", mode="overwrite"
    )
    table_df = session.table(table_name)

    def run_sample(frame):
        # Stratified sample with a fixed seed; sorted so results compare stably.
        return (
            frame.stat.sample_by(col("empid"), {1: 0.5, 2: 0.5}, seed=1)
            .sort(frame.columns)
            .collect()
        )

    # On a Table, the seed makes the sampling deterministic across repeated runs.
    expected = run_sample(table_df)
    for _ in range(3):
        Utils.check_answer(run_sample(table_df), expected)

    # On a plain DataFrame the seed is ignored and a warning is emitted instead.
    caplog.clear()
    with caplog.at_level(logging.WARNING):
        run_sample(TestData.monthly_sales(session))
    assert "`seed` argument is ignored on `DataFrame` object" in caplog.text


@pytest.mark.skipif(
"config.getoption('local_testing_mode', default=False)",
reason="FEAT: RelationalGroupedDataFrame.Pivot not supported",
Expand Down
Loading