Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,22 @@
- `regr_sxy`
- `regr_syy`
- `try_to_binary`
- `base64`
- `base64_decode_string`
- `base64_encode`
- `editdistance`
- `hex`
- `hex_encode`
- `instr`
- `levenshtein`
- `log1p`
- `log2`
- `log10`
- `percentile_approx`
- `unbase64`
- Added support for specifying a schema string (including implicit struct syntax) when calling `DataFrame.create_dataframe`.
- Added support for `DataFrameWriter.insert_into/insertInto`. This method also supports local testing mode.
- Added support for multiple columns in the functions `map_cat` and `map_concat`.

#### Experimental Features

Expand Down
13 changes: 13 additions & 0 deletions docs/source/snowpark/functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ Functions
atanh
atan2
avg
base64
base64_decode_string
base64_encode
bit_length
bitmap_bit_position
bitmap_bucket_number
Expand Down Expand Up @@ -157,6 +160,7 @@ Functions
desc_nulls_last
div0
divnull
editdistance
endswith
equal_nan
equal_null
Expand All @@ -178,12 +182,15 @@ Functions
grouping
grouping_id
hash
hex
hex_encode
hour
iff
ifnull
in_
initcap
insert
instr
is_array
is_binary
is_boolean
Expand Down Expand Up @@ -211,12 +218,16 @@ Functions
least
left
length
levenshtein
listagg
lit
ln
locate
localtimestamp
log
log1p
log2
log10
lower
lpad
ltrim
Expand Down Expand Up @@ -257,6 +268,7 @@ Functions
parse_json
parse_xml
percent_rank
percentile_approx
percentile_cont
position
pow
Expand Down Expand Up @@ -350,6 +362,7 @@ Functions
udaf
udf
udtf
unbase64
uniform
unix_timestamp
upper
Expand Down
238 changes: 235 additions & 3 deletions src/snowflake/snowpark/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,7 @@
import functools
import sys
import typing
from functools import reduce
from random import randint
from types import ModuleType
from typing import Callable, Dict, List, Optional, Tuple, Union, overload
Expand Down Expand Up @@ -1345,6 +1346,9 @@ def approx_percentile(
)


percentile_approx = approx_percentile


@publicapi
def approx_percentile_accumulate(col: ColumnOrName, _emit_ast: bool = True) -> Column:
"""Returns the internal representation of the t-Digest state (as a JSON object) at the end of aggregation.
Expand Down Expand Up @@ -2956,6 +2960,63 @@ def log(
return builtin("log", _ast=ast, _emit_ast=False)(b, arg)


@publicapi
def log1p(
    x: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns the natural logarithm of (1 + x).

    Args:
        x: A column (or column name) or a numeric literal.
        _emit_ast: Whether to emit the client AST for this expression.

    Example::

        >>> df = session.create_dataframe([0, 1], schema=["a"])
        >>> df.select(log1p(df["a"]).alias("log1p")).collect()
        [Row(LOG1P=0.0), Row(LOG1P=0.6931471805599453)]
    """
    # Normalize the input to a Column exactly once. The previous version
    # converted twice and passed "log" (instead of "log1p") as the function
    # name used in _to_col_if_str's error message.
    col_x = (
        lit(x, _emit_ast=False)
        if isinstance(x, (int, float))
        else _to_col_if_str(x, "log1p")
    )
    # LOG1P has no dedicated builtin here, so emulate it as LN(1 + x).
    one_plus_x = col_x + lit(1, _emit_ast=False)
    return ln(one_plus_x, _emit_ast=_emit_ast)


@publicapi
def log10(
    x: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns the base-10 logarithm of ``x``.

    Example::

        >>> df = session.create_dataframe([1, 10], schema=["a"])
        >>> df.select(log10(df["a"]).alias("log10")).collect()
        [Row(LOG10=0.0), Row(LOG10=1.0)]
    """
    # Delegate to the shared base-10 wrapper (also used by the Modin bridge).
    result = _log10(x, _emit_ast=_emit_ast)
    return result


@publicapi
def log2(
    x: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
) -> Column:
    """
    Returns the base-2 logarithm of ``x``.

    Example::

        >>> df = session.create_dataframe([1, 2, 8], schema=["a"])
        >>> df.select(log2(df["a"]).alias("log2")).collect()
        [Row(LOG2=0.0), Row(LOG2=1.0), Row(LOG2=3.0)]
    """
    # Delegate to the shared base-2 wrapper (also used by the Modin bridge).
    result = _log2(x, _emit_ast=_emit_ast)
    return result


# Create base 2 and base 10 wrappers for use with the Modin log2 and log10 functions
def _log2(x: Union[ColumnOrName, int, float], _emit_ast: bool = True) -> Column:
    # Base-2 logarithm expressed via the two-argument LOG builtin: LOG(2, x).
    return log(2, x, _emit_ast=_emit_ast)
Expand Down Expand Up @@ -7112,12 +7173,15 @@ def array_unique_agg(col: ColumnOrName, _emit_ast: bool = True) -> Column:


@publicapi
def map_cat(col1: ColumnOrName, col2: ColumnOrName, _emit_ast: bool = True):
"""Returns the concatenatation of two MAPs.
def map_cat(
col1: ColumnOrName, col2: ColumnOrName, *cols: ColumnOrName, _emit_ast: bool = True
):
"""Returns the concatenatation of two or more MAPs.

Args:
col1: The source map
col2: The map to be appended to col1
cols: More maps to be appended

Example::
>>> df = session.sql("select {'k1': 'v1'} :: MAP(STRING,STRING) as A, {'k2': 'v2'} :: MAP(STRING,STRING) as B")
Expand All @@ -7131,10 +7195,31 @@ def map_cat(col1: ColumnOrName, col2: ColumnOrName, _emit_ast: bool = True):
|} |
---------------------------
<BLANKLINE>
>>> df = session.sql("select {'k1': 'v1'} :: MAP(STRING,STRING) as A, {'k2': 'v2'} :: MAP(STRING,STRING) as B, {'k3': 'v3'} :: MAP(STRING,STRING) as C")
>>> df.select(map_cat("A", "B", "C")).show()
-------------------------------------------
|"MAP_CAT(MAP_CAT(""A"", ""B""), ""C"")" |
-------------------------------------------
|{ |
| "k1": "v1", |
| "k2": "v2", |
| "k3": "v3" |
|} |
-------------------------------------------
<BLANKLINE>
"""
m1 = _to_col_if_str(col1, "map_cat")
m2 = _to_col_if_str(col2, "map_cat")
return builtin("map_cat", _emit_ast=_emit_ast)(m1, m2)
ast = build_function_expr("map_cat", [col1, col2, *cols]) if _emit_ast else None

def map_cat_two_maps(first, second):
return builtin("map_cat", _ast=ast, _emit_ast=_emit_ast)(first, second)

cols_to_concat = [m1, m2]
for c in cols:
cols_to_concat.append(_to_col_if_str(c, "map_cat"))

return reduce(map_cat_two_maps, cols_to_concat)


@publicapi
Expand Down Expand Up @@ -11064,3 +11149,150 @@ def try_to_binary(
if fmt
else builtin("try_to_binary", _emit_ast=_emit_ast)(c)
)


@publicapi
def base64_encode(
    e: ColumnOrName,
    max_line_length: Optional[int] = 0,
    alphabet: Optional[str] = None,
    _emit_ast: bool = True,
) -> Column:
    """
    Encodes the input (string or binary) using Base64 encoding.

    Args:
        e: The column (or column name) containing the value to encode.
        max_line_length: Optional maximum line length of the output; 0 (the
            default) inserts no line breaks.
        alphabet: Optional string overriding the last two index characters and
            the padding character of the Base64 alphabet.

    Example:
        >>> df = session.create_dataframe(["Snowflake", "Data"], schema=["input"])
        >>> df.select(base64_encode(col("input")).alias("encoded")).collect()
        [Row(ENCODED='U25vd2ZsYWtl'), Row(ENCODED='RGF0YQ==')]
    """
    # Record all caller-supplied arguments in the AST.
    ast = (
        build_function_expr("base64_encode", [e, max_line_length, alphabet])
        if _emit_ast
        else None
    )
    # Convert input to a column if it is not already one.
    col_input = _to_col_if_str(e, "base64_encode")

    # BASE64_ENCODE's SQL arguments are positional. The previous version
    # dropped max_line_length whenever it was falsy (0/None), which shifted a
    # supplied alphabet into the max_line_length slot. When alphabet is given,
    # max_line_length must therefore always be emitted.
    args = [col_input]
    if alphabet:
        args.append(lit(max_line_length or 0))
        args.append(lit(alphabet))
    elif max_line_length:
        args.append(lit(max_line_length))

    # Call the built-in Base64 encode function.
    return builtin("base64_encode", _ast=ast, _emit_ast=_emit_ast)(*args)


# Spark-compatible alias.
base64 = base64_encode


@publicapi
def base64_decode_string(
    e: ColumnOrName, alphabet: Optional[str] = None, _emit_ast: bool = True
) -> Column:
    """
    Decodes a Base64-encoded string to a string.

    Args:
        e: The column (or column name) containing the Base64-encoded input.
        alphabet: Optional string overriding the Base64 alphabet used to decode.

    Example:
        >>> df = session.create_dataframe(["U25vd2ZsYWtl", "SEVMTE8="], schema=["input"])
        >>> df.select(base64_decode_string(col("input")).alias("decoded")).collect()
        [Row(DECODED='Snowflake'), Row(DECODED='HELLO')]
    """
    # Build the AST first so the recorded arguments match what the caller passed.
    ast = (
        build_function_expr("base64_decode_string", [e, alphabet])
        if _emit_ast
        else None
    )
    # Convert input to a column if it is not already one.
    col_input = _to_col_if_str(e, "base64_decode_string")

    # Prepare arguments for the function call.
    args = [col_input]

    if alphabet:
        args.append(lit(alphabet))

    # Call the built-in Base64 decode function.
    return builtin("base64_decode_string", _ast=ast, _emit_ast=_emit_ast)(*args)


# Spark-compatible alias.
unbase64 = base64_decode_string


@publicapi
def hex_encode(e: ColumnOrName, case: int = 1, _emit_ast: bool = True) -> Column:
    """
    Encodes the input using hexadecimal (also 'hex' or 'base16') encoding.

    Args:
        e: The column (or column name) containing the value to encode.
        case: Letter case of the output digits; the default of 1 produces
            uppercase hex as shown in the example. NOTE(review): presumably 0
            yields lowercase — confirm against Snowflake's HEX_ENCODE docs.

    Example:
        >>> df = session.create_dataframe(["Snowflake", "Hello"], schema=["input"])
        >>> df.select(hex_encode(col("input")).alias("hex_encoded")).collect()
        [Row(HEX_ENCODED='536E6F77666C616B65'), Row(HEX_ENCODED='48656C6C6F')]
    """
    ast = build_function_expr("hex_encode", [e, case]) if _emit_ast else None
    col_input = _to_col_if_str(e, "hex_encode")
    return builtin("hex_encode", _ast=ast, _emit_ast=_emit_ast)(col_input, lit(case))


# Spark-compatible alias.
hex = hex_encode


@publicapi
def editdistance(
    e1: ColumnOrName,
    e2: ColumnOrName,
    max_distance: Optional[Union[int, ColumnOrName]] = None,
    _emit_ast: bool = True,
) -> Column:
    """Computes the Levenshtein distance between two input strings.

    Optionally, a maximum distance can be specified. If the distance exceeds this value,
    the computation halts and returns the maximum distance.

    Args:
        e1: The first string column (or column name).
        e2: The second string column (or column name).
        max_distance: Optional cap on the computed distance, as an int literal
            or a column.

    Example::

        >>> df = session.create_dataframe(
        ...     [["abc", "def"], ["abcdef", "abc"], ["snow", "flake"]],
        ...     schema=["s1", "s2"]
        ... )
        >>> df.select(
        ...     editdistance(col("s1"), col("s2")).alias("distance"),
        ...     editdistance(col("s1"), col("s2"), 2).alias("max_2_distance")
        ... ).collect()
        [Row(DISTANCE=3, MAX_2_DISTANCE=2), Row(DISTANCE=3, MAX_2_DISTANCE=2), Row(DISTANCE=5, MAX_2_DISTANCE=2)]
    """
    # Record every caller-supplied argument in the AST. The previous version
    # always built the AST from [e1, e2], silently dropping max_distance from
    # the recorded expression.
    ast = (
        build_function_expr(
            "editdistance",
            [e1, e2] if max_distance is None else [e1, e2, max_distance],
        )
        if _emit_ast
        else None
    )
    s1 = _to_col_if_str(e1, "editdistance")
    s2 = _to_col_if_str(e2, "editdistance")

    args = [s1, s2]
    if max_distance is not None:
        max_dist = (
            lit(max_distance)
            if isinstance(max_distance, int)
            else _to_col_if_str(max_distance, "editdistance")
        )
        args.append(max_dist)

    return builtin("editdistance", _ast=ast, _emit_ast=_emit_ast)(*args)


# Spark-compatible alias.
levenshtein = editdistance


@publicapi
def instr(str: ColumnOrName, substr: str) -> Column:
    """
    Locate the position of the first occurrence of substr column in the given string. Returns null if either of the arguments are null.

    The returned position is 1-based; 0 is not produced in the examples below.

    Args:
        str: The column (or column name) to search in. NOTE(review): the
            parameter name shadows the builtin ``str`` but is part of the
            public Spark-compatible signature, so it cannot be renamed safely.
        substr: The literal substring to locate.

    Example::
        >>> df = session.create_dataframe([["hello world"], ["world hello"]], schema=["text"])
        >>> df.select(instr(col("text"), "world").alias("position")).collect()
        [Row(POSITION=7), Row(POSITION=1)]
    """
    # NOTE(review): unlike sibling functions, this takes no _emit_ast flag;
    # AST emission is delegated to position(). Confirm this is intentional.
    s1 = _to_col_if_str(str, "instr")
    # Delegates to POSITION(substr, str).
    return position(lit(substr), s1)
10 changes: 10 additions & 0 deletions tests/ast/data/functions2.test
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,16 @@ df304 = df.select(locate("needle", col("expr")), locate("needle", lit("test stri

df305 = df.select(size(col("expr")), size("A"))

df306 = df.select(base64_encode("A"))

df307 = df.select(base64_decode_string("A"))

df308 = df.select(hex_encode("A"))

df309 = df.select(editdistance("A", "B"))

df310 = df.select(map_cat("A", "B"))

## EXPECTED UNPARSER OUTPUT

df = session.table("table1")
Expand Down
Loading