Skip to content

Commit 51003a5

Browse files
SNOW-2396205: Support random_state in sample.
Signed-off-by: sfc-gh-mvashishtha <[email protected]>
1 parent aad530a commit 51003a5

File tree

4 files changed

+328
-50
lines changed

4 files changed

+328
-50
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575

7676
#### New Features
7777
- Added support for the `dtypes` parameter of `pd.get_dummies`
78+
- Added support for the `random_state` parameter of `DataFrame.sample` and `Series.sample`
7879

7980
#### Improvements
8081

@@ -148,6 +149,7 @@
148149
#### Bug Fixes
149150

150151
- Fixed a bug where the row count was not getting cached in the ordered dataframe each time count_rows() is called.
152+
- Fixed bug where `Series.sample` and `DataFrame.sample` would allow setting `n` larger than the number of rows while `replace=False`.
151153

152154
## 1.40.0 (2025-10-02)
153155

src/snowflake/snowpark/modin/plugin/compiler/snowflake_query_compiler.py

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@
7878
is_bool,
7979
is_bool_dtype,
8080
is_datetime64_any_dtype,
81+
is_integer,
8182
is_integer_dtype,
8283
is_named_tuple,
8384
is_numeric_dtype,
@@ -151,7 +152,6 @@
151152
object_keys,
152153
pandas_udf,
153154
quarter,
154-
random,
155155
rank,
156156
regexp_replace,
157157
reverse,
@@ -15950,59 +15950,119 @@ def sample(
1595015950
)
1595115951

1595215952
# handle axis = 0
15953+
1595315954
if weights is not None:
1595415955
ErrorMessage.not_implemented("`weights` is not supported.")
15956+
if isinstance(
15957+
random_state,
15958+
(
15959+
np.ndarray,
15960+
np.random.BitGenerator,
15961+
np.random.RandomState,
15962+
np.random.Generator,
15963+
),
15964+
):
15965+
ErrorMessage.not_implemented("non-integer `random_state` is not supported.")
1595515966

15956-
if random_state is not None:
15957-
ErrorMessage.not_implemented("`random_state` is not supported.")
15958-
15967+
if random_state is not None and not is_integer(random_state):
15968+
raise ValueError("random_state must be an integer or None.")
1595915969
assert n is not None or frac is not None
15970+
if not replace:
15971+
if frac is not None and frac > 1:
15972+
raise ValueError(
15973+
"Replace has to be set to `True` when upsampling the population `frac` > 1."
15974+
)
15975+
if n is not None and n > self.get_axis_len(axis=0):
15976+
raise ValueError(
15977+
"Cannot take a larger sample than population when 'replace=False'"
15978+
)
15979+
1596015980
frame = self._modin_frame
15961-
if replace:
15962-
sampled_row_position_identifier = (
15963-
generate_snowflake_quoted_identifiers_helper(
15964-
pandas_labels=[
15965-
SAMPLED_ROW_POSITION_COLUMN_LABEL,
15966-
]
15967-
)[0]
15968-
)
1596915981

15982+
# use builtin('random') instead of snowflake.snowpark.functions.random
15983+
# because the latter does not take Column inputs, but we want to use
15984+
# pandas_lit() to create the seed.
15985+
# if random_state is None, we have to call random() with no arguments.
15986+
# random(NULL) is not valid.
15987+
random_maybe_with_state = builtin("random")(
15988+
*(tuple() if random_state is None else (pandas_lit(random_state),))
15989+
)
15990+
if replace:
15991+
# If `replace=True`, we can't use snowflake's built-in SAMPLE, which
15992+
# samples without replacement.
1597015993
pre_sampling_rowcount = self.get_axis_len(axis=0)
1597115994
if n is not None:
1597215995
post_sampling_rowcount = n
1597315996
else:
1597415997
assert frac is not None
1597515998
post_sampling_rowcount = round(frac * pre_sampling_rowcount)
1597615999

15977-
sampled_row_position_col = uniform(
15978-
0, pre_sampling_rowcount - 1, random()
15979-
).as_(sampled_row_position_identifier)
15980-
16000+
sampled_row_position_identifier = (
16001+
generate_snowflake_quoted_identifiers_helper(
16002+
pandas_labels=[
16003+
SAMPLED_ROW_POSITION_COLUMN_LABEL,
16004+
]
16005+
)[0]
16006+
)
1598116007
sampled_row_positions_snowpark_frame = pd.session.generator(
15982-
sampled_row_position_col,
16008+
uniform(0, pre_sampling_rowcount - 1, random_maybe_with_state).as_(
16009+
sampled_row_position_identifier
16010+
),
1598316011
rowcount=post_sampling_rowcount,
1598416012
)
15985-
1598616013
sampled_row_positions_odf = OrderedDataFrame(
1598716014
dataframe_ref=DataFrameReference(sampled_row_positions_snowpark_frame),
1598816015
projected_column_snowflake_quoted_identifiers=[
1598916016
sampled_row_position_identifier
1599016017
],
1599116018
)
15992-
sampled_odf = cache_result(
15993-
sampled_row_positions_odf.join(
15994-
right=self._modin_frame.ordered_dataframe,
15995-
left_on_cols=[sampled_row_position_identifier],
15996-
right_on_cols=[
15997-
self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier
15998-
],
16019+
sampled_odf = sampled_row_positions_odf.join(
16020+
right=self._modin_frame.ordered_dataframe,
16021+
left_on_cols=[sampled_row_position_identifier],
16022+
right_on_cols=[
16023+
self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier
16024+
],
16025+
)
16026+
# if random_state is not None, the result is already deterministic.
16027+
if random_state is None:
16028+
logging.warning(
16029+
"Snowpark pandas `sample` will create a temp table for "
16030+
+ "sampled results to keep it deterministic."
16031+
)
16032+
sampled_odf = cache_result(sampled_odf)
16033+
elif random_state is not None:
16034+
# Snowflake's SAMPLE only accepts a seed when sampling from a
16035+
# table. A snowflake query compiler does not necessarily correspond
16036+
# to a particular snowflake table, and even though we could sample
16037+
# an intermediate table produce with cache_result(), we need to
16038+
# select a set of rows that is deterministic with respect to the
16039+
# table length rather than with respect to the query compiler or
16040+
# even the dataframe. For example, pd.DataFrame(list(range(1000))).sample(n=1, random_state=0) and
16041+
# pd.DataFrame(list(range(1000))[::-1]).sample(n=1, random_state=0)
16042+
# select the same row position.
16043+
if n is not None:
16044+
post_sampling_rowcount = n
16045+
else:
16046+
assert frac is not None
16047+
pre_sampling_rowcount = self.get_axis_len(axis=0)
16048+
post_sampling_rowcount = round(frac * pre_sampling_rowcount)
16049+
# Choose the top `post_sampling_rowcount` rows according to a random
16050+
# order.
16051+
new_identifier = self._modin_frame.ordered_dataframe.generate_snowflake_quoted_identifiers(
16052+
pandas_labels=["random_row_position"]
16053+
)[
16054+
0
16055+
]
16056+
sampled_odf = (
16057+
self._modin_frame.ordered_dataframe.select(
16058+
*self._modin_frame.ordered_dataframe.projected_column_snowflake_quoted_identifiers,
16059+
random_maybe_with_state.as_(new_identifier),
1599916060
)
16061+
.sort(OrderingColumn(new_identifier))
16062+
.limit(post_sampling_rowcount)
1600016063
)
1600116064
else:
1600216065
sampled_odf = frame.ordered_dataframe.sample(n=n, frac=frac)
16003-
logging.warning(
16004-
"Snowpark pandas `sample` will create a temp table for sampled results to keep it deterministic."
16005-
)
1600616066
res = SnowflakeQueryCompiler(
1600716067
InternalFrame.create(
1600816068
ordered_dataframe=sampled_odf,

0 commit comments

Comments
 (0)