Skip to content

Commit e0229ec

Browse files
sfc-gh-mvashishthasfc-gh-joshigraphite-app[bot]
authored
SNOW-2396205: Support random_state in sample. (#3918)
Implement the random_state parameter for DataFrame.sample and Series.sample by sorting with a seeded random order and then selecting the top `n` or `frac * len(dataset)` rows. We use this solution because we can't use the built-in SAMPLE with SEED for this use case. Signed-off-by: sfc-gh-mvashishtha <[email protected]> Co-authored-by: Jonathan Shi <[email protected]> Co-authored-by: graphite-app[bot] <96075541+graphite-app[bot]@users.noreply.github.com>
1 parent b7ad653 commit e0229ec

File tree

6 files changed

+322
-43
lines changed

6 files changed

+322
-43
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
- Added support for `DataFrame.groupby.rolling()`.
2424
- Added support for mapping `np.percentile` with DataFrame and Series inputs to `Series.quantile`.
25+
- Added support for setting the `random_state` parameter to an integer when calling `DataFrame.sample` or `Series.sample`.
2526

2627
#### Bug Fixes
2728

docs/source/modin/supported/dataframe_supported.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,8 +384,13 @@ Methods
384384
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
385385
| ``rtruediv`` | P | ``level`` | |
386386
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
387-
| ``sample`` | P | | ``N`` if ``weights`` or ``random_state`` is |
388-
| | | | specified when ``axis = 0`` |
387+
| ``sample`` | P | | ``N`` if ``weights`` is specified when |
388+
| | | | ``axis = 0``, or if ``random_state`` is not |
389+
| | | | either an integer or ``None``. Setting |
390+
| | | | ``random_state`` to a value other than ``None`` |
391+
| | | | may slow down this method because the ``sample`` |
392+
| | | | implementation will use a sort instead of the |
393+
| | | | Snowflake warehouse's built-in SAMPLE construct. |
389394
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
390395
| ``select_dtypes`` | Y | | |
391396
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

docs/source/modin/supported/series_supported.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,8 +383,13 @@ Methods
383383
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
384384
| ``rtruediv`` | P | ``level`` | See ``truediv`` |
385385
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
386-
| ``sample`` | P | | ``N`` if ``weights`` or ``random_state`` is |
387-
| | | | specified when ``axis = 0`` |
386+
| ``sample`` | P | | ``N`` if ``weights`` is specified when |
387+
| | | | ``axis = 0``, or if ``random_state`` is not |
388+
| | | | either an integer or ``None``. Setting |
389+
| | | | ``random_state`` to a value other than ``None`` |
390+
| | | | may slow down this method because the ``sample`` |
391+
| | | | implementation will use a sort instead of the |
392+
| | | | Snowflake warehouse's built-in SAMPLE construct. |
388393
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+
389394
| ``searchsorted`` | N | | |
390395
+-----------------------------+---------------------------------+----------------------------------+----------------------------------------------------+

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
is_bool,
7979
is_bool_dtype,
8080
is_datetime64_any_dtype,
81+
is_integer,
8182
is_integer_dtype,
8283
is_named_tuple,
8384
is_numeric_dtype,
@@ -151,7 +152,6 @@
151152
object_keys,
152153
pandas_udf,
153154
quarter,
154-
random,
155155
rank,
156156
regexp_replace,
157157
reverse,
@@ -16422,59 +16422,120 @@ def sample(
1642216422
)
1642316423

1642416424
# handle axis = 0
16425+
1642516426
if weights is not None:
1642616427
ErrorMessage.not_implemented("`weights` is not supported.")
16428+
if isinstance(
16429+
random_state,
16430+
(
16431+
np.ndarray,
16432+
np.random.BitGenerator,
16433+
np.random.RandomState,
16434+
np.random.Generator,
16435+
),
16436+
):
16437+
ErrorMessage.not_implemented("non-integer `random_state` is not supported.")
1642716438

16428-
if random_state is not None:
16429-
ErrorMessage.not_implemented("`random_state` is not supported.")
16430-
16439+
if random_state is not None and not is_integer(random_state):
16440+
raise ValueError("random_state must be an integer or None.")
1643116441
assert n is not None or frac is not None
16432-
frame = self._modin_frame
16433-
if replace:
16434-
sampled_row_position_identifier = (
16435-
generate_snowflake_quoted_identifiers_helper(
16436-
pandas_labels=[
16437-
SAMPLED_ROW_POSITION_COLUMN_LABEL,
16438-
]
16439-
)[0]
16442+
if not replace and frac is not None and frac > 1:
16443+
raise ValueError(
16444+
"Replace has to be set to `True` when upsampling the population `frac` > 1."
1644016445
)
1644116446

16447+
frame = self._modin_frame
16448+
16449+
# use builtin('random') instead of snowflake.snowpark.functions.random
16450+
# because the latter does not take Column inputs, but we want to use
16451+
# pandas_lit() to create the seed.
16452+
# if random_state is None, we have to call random() with no arguments.
16453+
# random(NULL) is not valid.
16454+
builtin_random = builtin("random")
16455+
random_column = (
16456+
builtin_random()
16457+
if random_state is None
16458+
else builtin_random(pandas_lit(random_state))
16459+
)
16460+
if replace:
16461+
# If `replace=True`, we can't use snowflake's built-in SAMPLE, which
16462+
# samples without replacement.
1644216463
pre_sampling_rowcount = self.get_axis_len(axis=0)
1644316464
if n is not None:
1644416465
post_sampling_rowcount = n
1644516466
else:
1644616467
assert frac is not None
1644716468
post_sampling_rowcount = round(frac * pre_sampling_rowcount)
1644816469

16449-
sampled_row_position_col = uniform(
16450-
0, pre_sampling_rowcount - 1, random()
16451-
).as_(sampled_row_position_identifier)
16452-
16470+
sampled_row_position_identifier = (
16471+
generate_snowflake_quoted_identifiers_helper(
16472+
pandas_labels=[
16473+
SAMPLED_ROW_POSITION_COLUMN_LABEL,
16474+
]
16475+
)[0]
16476+
)
1645316477
sampled_row_positions_snowpark_frame = pd.session.generator(
16454-
sampled_row_position_col,
16478+
uniform(0, pre_sampling_rowcount - 1, random_column).as_(
16479+
sampled_row_position_identifier
16480+
),
1645516481
rowcount=post_sampling_rowcount,
1645616482
)
16457-
1645816483
sampled_row_positions_odf = OrderedDataFrame(
1645916484
dataframe_ref=DataFrameReference(sampled_row_positions_snowpark_frame),
1646016485
projected_column_snowflake_quoted_identifiers=[
1646116486
sampled_row_position_identifier
1646216487
],
1646316488
)
16464-
sampled_odf = cache_result(
16465-
sampled_row_positions_odf.join(
16466-
right=self._modin_frame.ordered_dataframe,
16467-
left_on_cols=[sampled_row_position_identifier],
16468-
right_on_cols=[
16469-
self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier
16470-
],
16489+
sampled_odf = sampled_row_positions_odf.join(
16490+
right=self._modin_frame.ordered_dataframe,
16491+
left_on_cols=[sampled_row_position_identifier],
16492+
right_on_cols=[
16493+
self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier
16494+
],
16495+
)
16496+
# if random_state is not None, the result is seeded and already deterministic.
16497+
if random_state is None:
16498+
logging.warning(
16499+
"Snowpark pandas `sample` will create a temp table for "
16500+
+ "sampled results to keep it deterministic."
16501+
)
16502+
sampled_odf = cache_result(sampled_odf)
16503+
elif random_state is not None:
16504+
# Snowflake's SAMPLE, while more performant than this approach,
16505+
# only accepts a seed when sampling from a table. A snowflake query
16506+
# compiler does not necessarily correspond to a particular snowflake
16507+
# table, and even though we could sample an intermediate table
16508+
# produced with cache_result(), we need to select a set of rows that
16509+
# is deterministic with respect to the table length rather than with
16510+
# respect to the query compiler or even the dataframe. For example,
16511+
# pd.DataFrame(list(range(1000))).sample(n=1, random_state=0) and
16512+
# pd.DataFrame(list(range(1000))[::-1]).sample(n=1, random_state=0)
16513+
# select the same row position.
16514+
# We use this alternate implementation rather than the generator one
16515+
# that we use for replace=True because we can avoid a join.
16516+
if n is not None:
16517+
post_sampling_rowcount = n
16518+
else:
16519+
assert frac is not None
16520+
pre_sampling_rowcount = self.get_axis_len(axis=0)
16521+
post_sampling_rowcount = round(frac * pre_sampling_rowcount)
16522+
# Choose the top `post_sampling_rowcount` rows according to a random
16523+
# order.
16524+
new_identifier = self._modin_frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
16525+
pandas_labels=["random_row_position"]
16526+
)[
16527+
0
16528+
]
16529+
sampled_odf = (
16530+
self._modin_frame.ordered_dataframe.select(
16531+
*self._modin_frame.ordered_dataframe.projected_column_snowflake_quoted_identifiers,
16532+
random_column.as_(new_identifier),
1647116533
)
16534+
.sort(OrderingColumn(new_identifier))
16535+
.limit(post_sampling_rowcount)
1647216536
)
1647316537
else:
1647416538
sampled_odf = frame.ordered_dataframe.sample(n=n, frac=frac)
16475-
logging.warning(
16476-
"Snowpark pandas `sample` will create a temp table for sampled results to keep it deterministic."
16477-
)
1647816539
res = SnowflakeQueryCompiler(
1647916540
InternalFrame.create(
1648016541
ordered_dataframe=sampled_odf,

0 commit comments

Comments
 (0)