|
     dense_rank,
     first_value,
     floor,
+    get,
     greatest,
     hour,
     iff,
@@ -16813,10 +16814,6 @@ def str_split(
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support non-str 'pat' argument"
             )
-        if expand:
-            ErrorMessage.not_implemented(
-                "Snowpark pandas doesn't support 'expand' argument"
-            )
         if regex:
             ErrorMessage.not_implemented(
                 "Snowpark pandas doesn't support 'regex' argument"
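
For reference (not part of the patch): removing this guard is what enables expand=True. A minimal pandas sketch of the semantics the new code mirrors, assuming an ordinary string Series:

    import pandas as pd

    s = pd.Series(["a,b,c", "a,b", None])
    # expand=True returns one column per split slot; rows with fewer
    # splits are padded with missing values, and a null input row
    # yields nulls in every output column.
    print(s.str.split(",", expand=True))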
@@ -16864,6 +16861,12 @@ def output_col(
             if np.isnan(n):
                 # Follow pandas behavior
                 return pandas_lit(np.nan)
+            elif n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which, based on our experiments, leaves
+                # the input column as is whenever the above condition is satisfied.
+                new_col = iff(
+                    column.is_null(), pandas_lit(None), array_construct(column)
+                )
             elif n <= 0:
                 # If all possible splits are requested, we just use SQL's split function.
                 new_col = builtin("split")(new_col, pandas_lit(new_pat))
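
An aside on the n < -1 branch (illustrative literals, not from the patch): the behavior the comment describes is consistent with pandas treating a multi-character pat as a regular expression when regex is unset, because Python's re.split performs no splits at all when maxsplit is negative, unlike str.split:

    import re

    # str.split treats any negative maxsplit as "split without limit" ...
    "a::b::c".split("::", -2)      # -> ['a', 'b', 'c']
    # ... while re.split performs no splits when maxsplit is negative,
    # matching the "input column left as is" behavior handled above.
    re.split("::", "a::b::c", -2)  # -> ['a::b::c']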
@@ -16907,9 +16910,93 @@ def output_col(
             )
             return self._replace_non_str(column, new_col)
 
-        new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
-            lambda col_name: output_col(col_name, pat, n)
-        )
+        def output_cols(
+            column: SnowparkColumn, pat: Optional[str], n: int, max_splits: int
+        ) -> list[SnowparkColumn]:
+            """
+            Returns the list of columns that the input column will be split into.
+            This is only used when expand=True.
+            Args:
+                column : SnowparkColumn
+                    Input column
+                pat : str
+                    String to split on
+                n : int
+                    Limit on the number of output splits
+                max_splits : int
+                    Maximum number of achievable splits across all values in the input column.
+                    This is needed so that rows with fewer splits than desired can be padded with nulls.
+            """
+            col = output_col(column, pat, n)
+            final_splits = 0
+
+            if np.isnan(n):
+                # Follow pandas behavior
+                final_splits = 1
+            elif n <= 0:
+                final_splits = max_splits
+            else:
+                final_splits = min(n + 1, max_splits)
+
+            if n < -1 and not pandas.isnull(pat) and len(str(pat)) > 1:
+                # Follow pandas behavior, which, based on our experiments, leaves
+                # the input column as is whenever the above condition is satisfied.
+                final_splits = 1
+
+            return [
+                iff(
+                    array_size(col) > pandas_lit(i),
+                    get(col, pandas_lit(i)),
+                    pandas_lit(None),
+                )
+                for i in range(final_splits)
+            ]
+
+        def get_max_splits() -> int:
+            """
+            Returns the maximum number of splits achievable
+            across all values stored in the input column.
+            """
+            splits_as_list_frame = self.str_split(
+                pat=pat,
+                n=-1,
+                expand=False,
+                regex=regex,
+            )._modin_frame
+
+            split_counts_frame = splits_as_list_frame.append_column(
+                "split_counts",
+                array_size(
+                    col(
+                        splits_as_list_frame.data_column_snowflake_quoted_identifiers[0]
+                    )
+                ),
+            )
+
+            max_count_rows = split_counts_frame.ordered_dataframe.agg(
+                max_(
+                    col(split_counts_frame.data_column_snowflake_quoted_identifiers[-1])
+                ).as_("max_count")
+            ).collect()
+
+            return max_count_rows[0][0]
+
+        if expand:
+            cols = output_cols(
+                col(self._modin_frame.data_column_snowflake_quoted_identifiers[0]),
+                pat,
+                n,
+                get_max_splits(),
+            )
+            new_internal_frame = self._modin_frame.project_columns(
+                list(range(len(cols))),
+                cols,
+            )
+        else:
+            new_internal_frame = self._modin_frame.apply_snowpark_function_to_columns(
+                lambda col_name: output_col(col_name, pat, n)
+            )
+
         return SnowflakeQueryCompiler(new_internal_frame)
 
     def str_rsplit(
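
For intuition, a plain-Python sketch (illustrative values and names, not part of the patch) of what get_max_splits and output_cols compute server-side: first find the widest split across all rows, then pad every row out to that width with nulls:

    values = ["a,b,c", "a,b", None]

    # get_max_splits: the widest unlimited split across all non-null rows.
    max_splits = max(len(v.split(",")) for v in values if v is not None)  # 3

    # output_cols: one output column per slot; shorter rows are padded with None.
    rows = [
        (v.split(",") + [None] * max_splits)[:max_splits]
        if v is not None
        else [None] * max_splits
        for v in values
    ]
    # rows == [['a', 'b', 'c'], ['a', 'b', None], [None, None, None]]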
|