|
78 | 78 | is_bool, |
79 | 79 | is_bool_dtype, |
80 | 80 | is_datetime64_any_dtype, |
| 81 | + is_integer, |
81 | 82 | is_integer_dtype, |
82 | 83 | is_named_tuple, |
83 | 84 | is_numeric_dtype, |
|
151 | 152 | object_keys, |
152 | 153 | pandas_udf, |
153 | 154 | quarter, |
154 | | - random, |
155 | 155 | rank, |
156 | 156 | regexp_replace, |
157 | 157 | reverse, |
@@ -16422,59 +16422,120 @@ def sample( |
16422 | 16422 | ) |
16423 | 16423 |
|
16424 | 16424 | # handle axis = 0 |
| 16425 | + |
16425 | 16426 | if weights is not None: |
16426 | 16427 | ErrorMessage.not_implemented("`weights` is not supported.") |
| 16428 | + if isinstance( |
| 16429 | + random_state, |
| 16430 | + ( |
| 16431 | + np.ndarray, |
| 16432 | + np.random.BitGenerator, |
| 16433 | + np.random.RandomState, |
| 16434 | + np.random.Generator, |
| 16435 | + ), |
| 16436 | + ): |
| 16437 | + ErrorMessage.not_implemented("non-integer `random_state` is not supported.") |
16427 | 16438 |
|
16428 | | - if random_state is not None: |
16429 | | - ErrorMessage.not_implemented("`random_state` is not supported.") |
16430 | | - |
| 16439 | + if random_state is not None and not is_integer(random_state): |
| 16440 | + raise ValueError("random_state must be an integer or None.") |
16431 | 16441 | assert n is not None or frac is not None |
16432 | | - frame = self._modin_frame |
16433 | | - if replace: |
16434 | | - sampled_row_position_identifier = ( |
16435 | | - generate_snowflake_quoted_identifiers_helper( |
16436 | | - pandas_labels=[ |
16437 | | - SAMPLED_ROW_POSITION_COLUMN_LABEL, |
16438 | | - ] |
16439 | | - )[0] |
| 16442 | + if not replace and frac is not None and frac > 1: |
| 16443 | + raise ValueError( |
| 16444 | + "Replace has to be set to `True` when upsampling the population `frac` > 1." |
16440 | 16445 | ) |
16441 | 16446 |
|
| 16447 | + frame = self._modin_frame |
| 16448 | + |
| 16449 | + # use builtin('random') instead of snowflake.snowpark.functions.random |
| 16450 | + # because the latter does not take Column inputs, but we want to use |
| 16451 | + # pandas_lit() to create the seed. |
| 16452 | + # if random_state is None, we have to call random() with no arguments. |
| 16453 | + # random(NULL) is not valid. |
| 16454 | + builtin_random = builtin("random") |
| 16455 | + random_column = ( |
| 16456 | + builtin_random() |
| 16457 | + if random_state is None |
| 16458 | + else builtin_random(pandas_lit(random_state)) |
| 16459 | + ) |
| 16460 | + if replace: |
| 16461 | + # If `replace=True`, we can't use snowflake's built-in SAMPLE, which |
| 16462 | + # samples without replacement. |
16442 | 16463 | pre_sampling_rowcount = self.get_axis_len(axis=0) |
16443 | 16464 | if n is not None: |
16444 | 16465 | post_sampling_rowcount = n |
16445 | 16466 | else: |
16446 | 16467 | assert frac is not None |
16447 | 16468 | post_sampling_rowcount = round(frac * pre_sampling_rowcount) |
16448 | 16469 |
|
16449 | | - sampled_row_position_col = uniform( |
16450 | | - 0, pre_sampling_rowcount - 1, random() |
16451 | | - ).as_(sampled_row_position_identifier) |
16452 | | - |
| 16470 | + sampled_row_position_identifier = ( |
| 16471 | + generate_snowflake_quoted_identifiers_helper( |
| 16472 | + pandas_labels=[ |
| 16473 | + SAMPLED_ROW_POSITION_COLUMN_LABEL, |
| 16474 | + ] |
| 16475 | + )[0] |
| 16476 | + ) |
16453 | 16477 | sampled_row_positions_snowpark_frame = pd.session.generator( |
16454 | | - sampled_row_position_col, |
| 16478 | + uniform(0, pre_sampling_rowcount - 1, random_column).as_( |
| 16479 | + sampled_row_position_identifier |
| 16480 | + ), |
16455 | 16481 | rowcount=post_sampling_rowcount, |
16456 | 16482 | ) |
16457 | | - |
16458 | 16483 | sampled_row_positions_odf = OrderedDataFrame( |
16459 | 16484 | dataframe_ref=DataFrameReference(sampled_row_positions_snowpark_frame), |
16460 | 16485 | projected_column_snowflake_quoted_identifiers=[ |
16461 | 16486 | sampled_row_position_identifier |
16462 | 16487 | ], |
16463 | 16488 | ) |
16464 | | - sampled_odf = cache_result( |
16465 | | - sampled_row_positions_odf.join( |
16466 | | - right=self._modin_frame.ordered_dataframe, |
16467 | | - left_on_cols=[sampled_row_position_identifier], |
16468 | | - right_on_cols=[ |
16469 | | - self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier |
16470 | | - ], |
| 16489 | + sampled_odf = sampled_row_positions_odf.join( |
| 16490 | + right=self._modin_frame.ordered_dataframe, |
| 16491 | + left_on_cols=[sampled_row_position_identifier], |
| 16492 | + right_on_cols=[ |
| 16493 | + self._modin_frame.ordered_dataframe.row_position_snowflake_quoted_identifier |
| 16494 | + ], |
| 16495 | + ) |
| 16496 | + # if random_state is not None, the result is seeded and already deterministic. |
| 16497 | + if random_state is None: |
| 16498 | + logging.warning( |
| 16499 | + "Snowpark pandas `sample` will create a temp table for " |
| 16500 | + + "sampled results to keep it deterministic." |
| 16501 | + ) |
| 16502 | + sampled_odf = cache_result(sampled_odf) |
| 16503 | + elif random_state is not None: |
| 16504 | + # Snowflake's SAMPLE, while more performant than this appraoch, |
| 16505 | + # only accepts a seed when sampling from a table. A snowflake query |
| 16506 | + # compiler does not necessarily correspond to a particular snowflake |
| 16507 | + # table, and even though we could sample an intermediate table |
| 16508 | + # produced with cache_result(), we need to select a set of rows that |
| 16509 | + # is deterministic with respect to the table length rather than with |
| 16510 | + # respect to the query compiler or even the dataframe. For example, |
| 16511 | + # pd.DataFrame(list(range(1000))).sample(n=1, random_state=0) and |
| 16512 | + # pd.DataFrame(list(range(1000))[::-1]).sample(n=1, random_state=0) |
| 16513 | + # select the same row position. |
| 16514 | + # We use this alternate implementation rather than the generator one |
| 16515 | + # that we use for replace=True because we can avoid a join. |
| 16516 | + if n is not None: |
| 16517 | + post_sampling_rowcount = n |
| 16518 | + else: |
| 16519 | + assert frac is not None |
| 16520 | + pre_sampling_rowcount = self.get_axis_len(axis=0) |
| 16521 | + post_sampling_rowcount = round(frac * pre_sampling_rowcount) |
| 16522 | + # Choose the top `post_sampling_rowcount` rows according to a random |
| 16523 | + # order. |
| 16524 | + new_identifier = self._modin_frame.ordered_dataframe.generate_snowflake_quoted_identifiers( |
| 16525 | + pandas_labels=["random_row_position"] |
| 16526 | + )[ |
| 16527 | + 0 |
| 16528 | + ] |
| 16529 | + sampled_odf = ( |
| 16530 | + self._modin_frame.ordered_dataframe.select( |
| 16531 | + *self._modin_frame.ordered_dataframe.projected_column_snowflake_quoted_identifiers, |
| 16532 | + random_column.as_(new_identifier), |
16471 | 16533 | ) |
| 16534 | + .sort(OrderingColumn(new_identifier)) |
| 16535 | + .limit(post_sampling_rowcount) |
16472 | 16536 | ) |
16473 | 16537 | else: |
16474 | 16538 | sampled_odf = frame.ordered_dataframe.sample(n=n, frac=frac) |
16475 | | - logging.warning( |
16476 | | - "Snowpark pandas `sample` will create a temp table for sampled results to keep it deterministic." |
16477 | | - ) |
16478 | 16539 | res = SnowflakeQueryCompiler( |
16479 | 16540 | InternalFrame.create( |
16480 | 16541 | ordered_dataframe=sampled_odf, |
|
0 commit comments