Skip to content

Commit 550e5b2

Browse files
SNOW-2334682: Add support for scalar functions from different categories. (#3773)
1 parent 3e9305e commit 550e5b2

File tree

3 files changed

+280
-0
lines changed

3 files changed

+280
-0
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,13 @@
1818
- `get_cloud_provider_token`
1919

2020
- Added support for the following scalar functions in `functions.py`:
21+
- `array_remove_at`
22+
- `as_boolean`
23+
- `boolor_agg`
24+
- `chr`
25+
- `div0null`
26+
- `dp_interval_high`
27+
- `dp_interval_low`
2128
- `h3_cell_to_boundary`
2229
- `h3_cell_to_parent`
2330
- `h3_cell_to_point`
@@ -28,6 +35,9 @@
2835
- `h3_get_resolution`
2936
- `h3_grid_disk`
3037
- `h3_grid_distance`
38+
- `hex_decode_binary`
39+
- `last_query_id`
40+
- `last_transaction`
3141

3242
### Snowpark pandas API Updates
3343

docs/source/snowpark/functions.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ Functions
5959
array_position
6060
array_prepend
6161
array_remove
62+
array_remove_at
6263
array_reverse
6364
array_size
6465
array_slice
@@ -70,6 +71,7 @@ Functions
7071
arrays_zip
7172
as_array
7273
as_binary
74+
as_boolean
7375
as_char
7476
as_date
7577
as_decimal
@@ -109,6 +111,7 @@ Functions
109111
bitshiftright
110112
bitxor
111113
bitxor_agg
114+
boolor_agg
112115
build_stage_file_url
113116
builtin
114117
bround
@@ -123,6 +126,7 @@ Functions
123126
charindex
124127
check_json
125128
check_xml
129+
chr
126130
coalesce
127131
col
128132
collate
@@ -187,6 +191,8 @@ Functions
187191
desc_nulls_last
188192
div0
189193
divnull
194+
dp_interval_high
195+
dp_interval_low
190196
editdistance
191197
endswith
192198
equal_nan
@@ -227,6 +233,7 @@ Functions
227233
grouping_id
228234
hash
229235
hex
236+
hex_decode_binary
230237
hex_encode
231238
hour
232239
h3_cell_to_boundary
@@ -275,6 +282,8 @@ Functions
275282
kurtosis
276283
lag
277284
last_day
285+
last_query_id
286+
last_transaction
278287
last_value
279288
lead
280289
least

src/snowflake/snowpark/_functions/scalar_functions.py

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,3 +763,264 @@ def h3_grid_distance(
763763
cell_id_1 = _to_col_if_str(cell_id_1, "h3_grid_distance")
764764
cell_id_2 = _to_col_if_str(cell_id_2, "h3_grid_distance")
765765
return builtin("h3_grid_distance", _emit_ast=_emit_ast)(cell_id_1, cell_id_2)
766+
767+
768+
@publicapi
def array_remove_at(
    array: ColumnOrName, position: ColumnOrName, _emit_ast: bool = True
) -> Column:
    """
    Removes the element at a given position from an ARRAY and returns the resulting ARRAY.

    Args:
        array (ColumnOrName): The source ARRAY, given as a Column or a column name.
        position (ColumnOrName): The (zero-based) position of the element to remove, given as a
            Column or a column name. A negative position counts from the back of the array
            (e.g. -1 removes the last element). In the example below, a position past the end
            of the array leaves the array unchanged.

    Returns:
        Column: The source ARRAY without the element at ``position``.

    Example::

        >>> df = session.create_dataframe([([2, 5, 7], 0), ([2, 5, 7], -1), ([2, 5, 7], 10)], schema=["array_col", "position_col"])
        >>> df.select(array_remove_at("array_col", "position_col").alias("result")).collect()
        [Row(RESULT='[\\n 5,\\n 7\\n]'), Row(RESULT='[\\n 2,\\n 5\\n]'), Row(RESULT='[\\n 2,\\n 5,\\n 7\\n]')]
    """
    # Resolve both arguments to Columns first, then look up and invoke the builtin.
    source_array = _to_col_if_str(array, "array_remove_at")
    remove_index = _to_col_if_str(position, "array_remove_at")
    remove_at = builtin("array_remove_at", _emit_ast=_emit_ast)
    return remove_at(source_array, remove_index)
793+
794+
795+
@publicapi
def as_boolean(variant: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Casts a VARIANT value to a boolean.

    Args:
        variant (ColumnOrName): A Column or column name containing VARIANT values to be cast to boolean.

    Returns:
        Column: The boolean values cast from the VARIANT input.

    Example::

        >>> from snowflake.snowpark.functions import to_variant, to_boolean
        >>> df = session.create_dataframe([
        ...     [True],
        ...     [False]
        ... ], schema=["a"])
        >>> df.select(as_boolean(to_variant(to_boolean(df["a"]))).alias("result")).collect()
        [Row(RESULT=True), Row(RESULT=False)]
    """
    # A string argument is interpreted as a column name.
    variant_col = _to_col_if_str(variant, "as_boolean")
    as_boolean_fn = builtin("as_boolean", _emit_ast=_emit_ast)
    return as_boolean_fn(variant_col)
817+
818+
819+
@publicapi
def boolor_agg(e: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Computes the logical OR over all non-NULL records in a group.
    The result is NULL when every record in the group is NULL.

    Args:
        e (ColumnOrName): The boolean values to aggregate, given as a Column or a column name.

    Returns:
        Column: The result of the logical OR aggregation.

    Example::

        >>> df = session.create_dataframe([
        ...     [True, False, True],
        ...     [False, False, False],
        ...     [True, True, False],
        ...     [False, True, True]
        ... ], schema=["a", "b", "c"])
        >>> df.select(
        ...     boolor_agg(df["a"]).alias("boolor_a"),
        ...     boolor_agg(df["b"]).alias("boolor_b"),
        ...     boolor_agg(df["c"]).alias("boolor_c")
        ... ).collect()
        [Row(BOOLOR_A=True, BOOLOR_B=True, BOOLOR_C=True)]
    """
    # A string argument is interpreted as a column name.
    bool_col = _to_col_if_str(e, "boolor_agg")
    boolor_agg_fn = builtin("boolor_agg", _emit_ast=_emit_ast)
    return boolor_agg_fn(bool_col)
847+
848+
849+
@publicapi
def chr(col: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Converts a Unicode code point (including 7-bit ASCII) into the character that matches
    the input Unicode.

    Args:
        col (ColumnOrName): Integer Unicode code points, given as a Column or a column name.

    Returns:
        Column: The character corresponding to each code point.

    Example::

        >>> df = session.create_dataframe([83, 33, 169, 8364, None], schema=['a'])
        >>> df.select(df.a, chr(df.a).as_('char')).sort(df.a).show()
        -----------------
        |"A"   |"CHAR"  |
        -----------------
        |NULL  |NULL    |
        |33    |!       |
        |83    |S       |
        |169   |©       |
        |8364  |€       |
        -----------------
        <BLANKLINE>
    """
    # A string argument is interpreted as a column name.
    code_point_col = _to_col_if_str(col, "chr")
    chr_fn = builtin("chr", _emit_ast=_emit_ast)
    return chr_fn(code_point_col)
877+
878+
879+
@publicapi
def div0null(
    dividend: Union[ColumnOrName, int, float],
    divisor: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Divides like the division operator (/), except that a divisor of 0 or NULL yields 0
    instead of an error.

    Args:
        dividend (ColumnOrName, int, float): The dividend. Python ints and floats are wrapped
            as literals; strings are interpreted as column names.
        divisor (ColumnOrName, int, float): The divisor. Python ints and floats are wrapped
            as literals; strings are interpreted as column names.

    Returns:
        Column: The quotient, or 0 where the divisor is 0 or NULL.

    Example::

        >>> df = session.create_dataframe([[10, 2], [10, 0], [10, None]], schema=["dividend", "divisor"])
        >>> df.select(div0null(df["dividend"], df["divisor"]).alias("result")).collect()
        [Row(RESULT=Decimal('5.000000')), Row(RESULT=Decimal('0.000000')), Row(RESULT=Decimal('0.000000'))]
    """
    # Numeric Python values become literal columns; everything else goes through
    # the usual column-name resolution.
    if isinstance(dividend, (int, float)):
        numerator = lit(dividend)
    else:
        numerator = _to_col_if_str(dividend, "div0null")

    if isinstance(divisor, (int, float)):
        denominator = lit(divisor)
    else:
        denominator = _to_col_if_str(divisor, "div0null")

    div0null_fn = builtin("div0null", _emit_ast=_emit_ast)
    return div0null_fn(numerator, denominator)
912+
913+
914+
@publicapi
def dp_interval_high(aggregated_column: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Returns the high end of the confidence interval for a differentially private aggregate.
    It is meant to be used together with differential privacy aggregation functions, for which
    it provides the upper bound of the confidence interval of the aggregated result.

    Args:
        aggregated_column (ColumnOrName): The result of a differential privacy aggregation
            function, given as a Column or a column name.

    Returns:
        Column: The high end of the confidence interval for the differentially private aggregate.

    Example::

        >>> from snowflake.snowpark.functions import sum as sum_
        >>> df = session.create_dataframe([[10], [20], [30]], schema=["num_claims"])
        >>> df.select(sum_(df["num_claims"]).alias("sum_claims")).select(dp_interval_high("sum_claims")).collect()
        [Row(DP_INTERVAL_HIGH("SUM_CLAIMS")=None)]
    """
    # A string argument is interpreted as a column name.
    aggregate_col = _to_col_if_str(aggregated_column, "dp_interval_high")
    interval_high_fn = builtin("dp_interval_high", _emit_ast=_emit_ast)
    return interval_high_fn(aggregate_col)
936+
937+
938+
@publicapi
def dp_interval_low(aggregated_column: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Returns the lower bound of the confidence interval for a differentially private aggregate.
    It is meant to be used together with differential privacy aggregation functions, for which
    it provides statistical bounds on the results.

    Args:
        aggregated_column (ColumnOrName): The differentially private aggregate result, given as
            a Column or a column name.

    Returns:
        Column: The lower bound of the confidence interval.

    Example::

        >>> from snowflake.snowpark.functions import sum as sum_
        >>> df = session.create_dataframe([[10], [20], [30]], schema=["num_claims"])
        >>> result = df.select(sum_("num_claims").alias("sum_claims")).select(dp_interval_low("sum_claims").alias("interval_low"))
        >>> result.collect()
        [Row(INTERVAL_LOW=None)]
    """
    # A string argument is interpreted as a column name.
    aggregate_col = _to_col_if_str(aggregated_column, "dp_interval_low")
    interval_low_fn = builtin("dp_interval_low", _emit_ast=_emit_ast)
    return interval_low_fn(aggregate_col)
959+
960+
961+
@publicapi
def hex_decode_binary(input_expr: ColumnOrName, _emit_ast: bool = True) -> Column:
    """
    Decodes a hex-encoded string into binary data.

    Args:
        input_expr (:class:`ColumnOrName`): The hex-encoded string to decode, given as a Column
            or a column name.

    Returns:
        :class:`Column`: The decoded binary data.

    Example::

        >>> df = session.create_dataframe(['48454C4C4F', '576F726C64'], schema=['hex_string'])
        >>> df.select(hex_decode_binary(df['hex_string']).alias('decoded_binary')).collect()
        [Row(DECODED_BINARY=bytearray(b'HELLO')), Row(DECODED_BINARY=bytearray(b'World'))]
    """
    # A string argument is interpreted as a column name.
    hex_col = _to_col_if_str(input_expr, "hex_decode_binary")
    hex_decode_fn = builtin("hex_decode_binary", _emit_ast=_emit_ast)
    return hex_decode_fn(hex_col)
979+
980+
981+
@publicapi
def last_query_id(num: ColumnOrName = None, _emit_ast: bool = True) -> Column:
    """
    Returns the query ID of the last statement executed in the current session.
    If num is specified, returns the query ID of the nth statement executed in the current session.

    Args:
        num (ColumnOrName, optional): Selects which statement's query ID to return. The value is
            forwarded unchanged to the ``last_query_id`` builtin (the example below passes a
            plain ``int``). If None, the builtin is called without arguments and the query ID
            of the last statement is returned.

    Returns:
        Column: The query ID as a string.

    Example::

        >>> df = session.create_dataframe([1], schema=["a"])
        >>> result1 = df.select(last_query_id().alias("QUERY_ID")).collect()
        >>> assert len(result1) == 1
        >>> assert isinstance(result1[0]["QUERY_ID"], str)
        >>> assert len(result1[0]["QUERY_ID"]) > 0
        >>> result2 = df.select(last_query_id(1).alias("QUERY_ID")).collect()
        >>> assert len(result2) == 1
        >>> assert isinstance(result2[0]["QUERY_ID"], str)
        >>> assert len(result2[0]["QUERY_ID"]) > 0
    """
    query_id_fn = builtin("last_query_id", _emit_ast=_emit_ast)
    # Without ``num`` the SQL function is invoked with no arguments at all.
    if num is None:
        return query_id_fn()
    return query_id_fn(num)
1009+
1010+
1011+
@publicapi
def last_transaction(_emit_ast: bool = True) -> Column:
    """
    Returns the transaction ID of the last transaction that was committed or rolled back in the
    current session. If no transaction has been committed or rolled back in the current session,
    returns NULL.

    Returns:
        Column: The ID of the last transaction, or NULL if there is none.

    Example::

        >>> df = session.create_dataframe([1])
        >>> result = df.select(last_transaction()).collect()
        >>> # Result will be None if no transaction has occurred
        >>> assert result[0]['LAST_TRANSACTION()'] is None or isinstance(result[0]['LAST_TRANSACTION()'], str)
    """
    # The SQL function takes no arguments.
    last_transaction_fn = builtin("last_transaction", _emit_ast=_emit_ast)
    return last_transaction_fn()

0 commit comments

Comments
 (0)