|
78 | 78 | is_bool, |
79 | 79 | is_bool_dtype, |
80 | 80 | is_datetime64_any_dtype, |
| 81 | + is_integer, |
81 | 82 | is_integer_dtype, |
82 | 83 | is_named_tuple, |
83 | 84 | is_numeric_dtype, |
|
151 | 152 | object_keys, |
152 | 153 | pandas_udf, |
153 | 154 | quarter, |
154 | | - random, |
155 | 155 | rank, |
156 | 156 | regexp_replace, |
157 | 157 | reverse, |
@@ -15950,59 +15950,119 @@ def sample( |
15950 | 15950 | ) |
15951 | 15951 |
|
15952 | 15952 | # handle axis = 0 |
| 15953 | + |
15953 | 15954 | if weights is not None: |
15954 | 15955 | ErrorMessage.not_implemented("`weights` is not supported.") |
| 15956 | + if isinstance( |
| 15957 | + random_state, |
| 15958 | + ( |
| 15959 | + np.ndarray, |
| 15960 | + np.random.BitGenerator, |
| 15961 | + np.random.RandomState, |
| 15962 | + np.random.Generator, |
| 15963 | + ), |
| 15964 | + ): |
| 15965 | + ErrorMessage.not_implemented("non-integer `random_state` is not supported.") |
15955 | 15966 |
|
15956 | | - if random_state is not None: |
15957 | | - ErrorMessage.not_implemented("`random_state` is not supported.") |
15958 | | - |
| 15967 | + if random_state is not None and not is_integer(random_state): |
| 15968 | + raise ValueError("random_state must be an integer or None.") |
15959 | 15969 | assert n is not None or frac is not None |
| 15970 | + if not replace: |
| 15971 | + if frac is not None and frac > 1: |
| 15972 | + raise ValueError( |
| 15973 | + "Replace has to be set to `True` when upsampling the population `frac` > 1." |
| 15974 | + ) |
| 15975 | + if n is not None and n > self.get_axis_len(axis=0): |
| 15976 | + raise ValueError( |
| 15977 | + "Cannot take a larger sample than population when 'replace=False'" |
| 15978 | + ) |
| 15979 | + |
15960 | 15980 | frame = self._modin_frame |
15961 | | - if replace: |
15962 | | - sampled_row_position_identifier = ( |
15963 | | - generate_snowflake_quoted_identifiers_helper( |
15964 | | - pandas_labels=[ |
15965 | | - SAMPLED_ROW_POSITION_COLUMN_LABEL, |
15966 | | - ] |
15967 | | - )[0] |
15968 | | - ) |
15969 | 15981 |
|
| 15982 | + # use builtin('random') instead of snowflake.snowpark.functions.random |
| 15983 | + # because the latter does not take Column inputs, but we want to use |
| 15984 | + # pandas_lit() to create the seed. |
| 15985 | + # if random_state is None, we have to call random() with no arguments. |
| 15986 | + # random(NULL) is not valid. |
| 15987 | + random_maybe_with_state = builtin("random")( |
| 15988 | + *(tuple() if random_state is None else (pandas_lit(random_state),)) |
| 15989 | + ) |
| 15990 | + if replace: |
| 15991 | + # If `replace=True`, we can't use snowflake's built-in SAMPLE, which |
| 15992 | + # samples without replacement. |
15970 | 15993 | pre_sampling_rowcount = self.get_axis_len(axis=0) |
15971 | 15994 | if n is not None: |
15972 | 15995 | post_sampling_rowcount = n |
15973 | 15996 | else: |
15974 | 15997 | assert frac is not None |
15975 | 15998 | post_sampling_rowcount = round(frac * pre_sampling_rowcount) |
15976 | 15999 |
|
15977 | | - sampled_row_position_col = uniform( |
15978 | | - 0, pre_sampling_rowcount - 1, random() |
15979 | | - ).as_(sampled_row_position_identifier) |
15980 | | - |
| 16000 | + sampled_row_position_identifier = ( |
| 16001 | + generate_snowflake_quoted_identifiers_helper( |
| 16002 | + pandas_labels=[ |
| 16003 | + SAMPLED_ROW_POSITION_COLUMN_LABEL, |
| 16004 | + ] |
| 16005 | + )[0] |
| 16006 | + ) |
15981 | 16007 | sampled_row_positions_snowpark_frame = pd.session.generator( |
15982 | | - sampled_row_position_col, |
| 16008 | + uniform(0, pre_sampling_rowcount - 1, random_maybe_with_state).as_( |
| 16009 | + sampled_row_position_identifier |
| 16010 | + ), |
15983 | 16011 | rowcount=post_sampling_rowcount, |
15984 | 16012 | ) |
15985 | | - |
15986 | 16013 | sampled_row_positions_odf = OrderedDataFrame( |
15987 | 16014 | dataframe_ref=DataFrameReference(sampled_row_positions_snowpark_frame), |
15988 | 16015 | projected_column_snowflake_quoted_identifiers=[ |
15989 | 16016 | sampled_row_position_identifier |
15990 | 16017 | ], |
15991 | 16018 | ) |
15992 | | - sampled_odf = cache_result( |
15993 | | - sampled_row_positions_odf.join( |
15994 | | - right=self._modin_frame.ordered_dataframe, |
15995 | | - left_on_cols=[sampled_row_position_identifier], |
15996 | | - right_on_cols=[ |
15997 | | - self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier |
15998 | | - ], |
| 16019 | + sampled_odf = sampled_row_positions_odf.join( |
| 16020 | + right=self._modin_frame.ordered_dataframe, |
| 16021 | + left_on_cols=[sampled_row_position_identifier], |
| 16022 | + right_on_cols=[ |
| 16023 | + self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier |
| 16024 | + ], |
| 16025 | + ) |
| 16026 | + # if random_state is not None, the result is already deterministic. |
| 16027 | + if random_state is None: |
| 16028 | + logging.warning( |
| 16029 | + "Snowpark pandas `sample` will create a temp table for " |
| 16030 | + + "sampled results to keep it deterministic." |
| 16031 | + ) |
| 16032 | + sampled_odf = cache_result(sampled_odf) |
| 16033 | + elif random_state is not None: |
| 16034 | + # Snowflake's SAMPLE only accepts a seed when sampling from a |
| 16035 | + # table. A snowflake query compiler does not necessarily correspond |
| 16036 | + # to a particular snowflake table, and even though we could sample |
| 16037 | + # an intermediate table produce with cache_result(), we need to |
| 16038 | + # select a set of rows that is deterministic with respect to the |
| 16039 | + # table length rather than with respect to the query compiler or |
| 16040 | + # even the dataframe. For example, pd.DataFrame(list(range(1000))).sample(n=1, random_state=0) and |
| 16041 | + # pd.DataFrame(list(range(1000))[::-1]).sample(n=1, random_state=0) |
| 16042 | + # select the same row position. |
| 16043 | + if n is not None: |
| 16044 | + post_sampling_rowcount = n |
| 16045 | + else: |
| 16046 | + assert frac is not None |
| 16047 | + pre_sampling_rowcount = self.get_axis_len(axis=0) |
| 16048 | + post_sampling_rowcount = round(frac * pre_sampling_rowcount) |
| 16049 | + # Choose the top `post_sampling_rowcount` rows according to a random |
| 16050 | + # order. |
| 16051 | + new_identifier = self._modin_frame.ordered_dataframe.generate_snowflake_quoted_identifiers( |
| 16052 | + pandas_labels=["random_row_position"] |
| 16053 | + )[ |
| 16054 | + 0 |
| 16055 | + ] |
| 16056 | + sampled_odf = ( |
| 16057 | + self._modin_frame.ordered_dataframe.select( |
| 16058 | + *self._modin_frame.ordered_dataframe.projected_column_snowflake_quoted_identifiers, |
| 16059 | + random_maybe_with_state.as_(new_identifier), |
15999 | 16060 | ) |
| 16061 | + .sort(OrderingColumn(new_identifier)) |
| 16062 | + .limit(post_sampling_rowcount) |
16000 | 16063 | ) |
16001 | 16064 | else: |
16002 | 16065 | sampled_odf = frame.ordered_dataframe.sample(n=n, frac=frac) |
16003 | | - logging.warning( |
16004 | | - "Snowpark pandas `sample` will create a temp table for sampled results to keep it deterministic." |
16005 | | - ) |
16006 | 16066 | res = SnowflakeQueryCompiler( |
16007 | 16067 | InternalFrame.create( |
16008 | 16068 | ordered_dataframe=sampled_odf, |
|
0 commit comments