Skip to content

Commit 1ad799d

Browse files
authored
SNOW-1901483: function api coverage normal and random impl (#2968)
1 parent 54eba82 commit 1ad799d

File tree

4 files changed

+230
-0
lines changed

4 files changed

+230
-0
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
# Release History
22

3+
# 1.28.0 (TBD)
4+
5+
### Snowpark Python API Updates
6+
7+
#### New Features
8+
9+
- Added support for the following functions in `functions.py`:
10+
- `normal`
11+
- `randn`
12+
313
## 1.27.0 (2025-02-03)
414

515
### Snowpark Python API Updates

docs/source/snowpark/functions.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ Functions
250250
months_between
251251
negate
252252
next_day
253+
normal
253254
not_
254255
nth_value
255256
ntile
@@ -275,6 +276,7 @@ Functions
275276
previous_day
276277
quarter
277278
radians
279+
randn
278280
random
279281
rank
280282
regexp_count

src/snowflake/snowpark/functions.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11415,3 +11415,61 @@ def instr(str: ColumnOrName, substr: str, _emit_ast: bool = True):
1141511415
ast = build_function_expr("instr", [str, substr]) if _emit_ast else None
1141611416
s1 = _to_col_if_str(str, "instr")
1141711417
return position(lit(substr), s1, _emit_ast=False, _ast=ast)
11418+
11419+
11420+
@publicapi
def normal(
    mean: Union[int, float],
    stddev: Union[int, float],
    gen: Union[ColumnOrName, int, float],
    _emit_ast: bool = True,
    _ast: Optional[proto.Expr] = None,
) -> Column:
    """
    Generates a normally-distributed pseudo-random floating point number with specified mean and stddev (standard deviation).

    Args:
        mean: Mean of the distribution. A Python ``int`` or ``float`` is wrapped
            in a literal; any other value is forwarded to the SQL function as-is
            (no type validation is performed here).
        stddev: Standard deviation of the distribution. Handled the same way as
            ``mean``.
        gen: The generator value fed to the SQL ``normal`` function. A Python
            ``int`` or ``float`` is wrapped in a literal; otherwise it is treated
            as a column or column name.

    Example::

        >>> df = session.create_dataframe([1,2,3], schema=["a"])
        >>> df.select(normal(0, 1, "a").alias("normal")).collect()
        [Row(NORMAL=-1.143416214223267), Row(NORMAL=-0.78469958830255), Row(NORMAL=-0.365971322006404)]
    """
    # SNOW-1906511: the SQL normal function rejects mean/stddev given as column
    # names when selecting from an inline VALUES clause, e.g. this fails:
    #   SELECT normal("A", "A", 2) FROM ( SELECT $1 AS "A" FROM VALUES (0 :: BIGINT))
    # It does work when reading from a regular table, so mean and stddev are
    # deliberately not type-validated here; that keeps the function usable on
    # regular tables.

    # Build the AST from the caller's original arguments, before they are
    # converted below. When AST emission is disabled, reuse the AST handed in
    # by the caller (may be None).
    ast = build_function_expr("normal", [mean, stddev, gen]) if _emit_ast else _ast

    # Only plain Python numbers are turned into literals; everything else is
    # passed through untouched (see the SNOW-1906511 note above).
    if isinstance(mean, (int, float)):
        mean = lit(mean, _emit_ast=False)
    if isinstance(stddev, (int, float)):
        stddev = lit(stddev, _emit_ast=False)

    # The generator may be a number, a Column, or a column name.
    if isinstance(gen, (int, float)):
        gen = lit(gen, _emit_ast=False)
    else:
        gen = _to_col_if_str(gen, "normal")

    return builtin("normal", _emit_ast=_emit_ast, _ast=ast)(mean, stddev, gen)
11451+
11452+
11453+
@publicapi
def randn(
    seed: Optional[Union[ColumnOrName, int, float]] = None, _emit_ast: bool = True
) -> Column:
    """
    Generates a column with independent and identically distributed (i.i.d.) samples from the standard normal distribution.

    This delegates to :func:`normal` with mean 0 and standard deviation 1. When
    ``seed`` is omitted, :func:`random` is used as the generator instead.

    Example::

        >>> df = session.create_dataframe([1,2,3], schema=["seed"])
        >>> df.select(randn("seed").alias("randn")).collect()
        [Row(RANDN=-1.143416214223267), Row(RANDN=-0.78469958830255), Row(RANDN=-0.365971322006404)]
        >>> df.select(randn().alias("randn")).collect() # doctest: +SKIP
    """
    # Capture the AST from the caller-supplied seed, before any default
    # generator is substituted for it.
    randn_ast = build_function_expr("randn", [seed]) if _emit_ast else None

    if seed is not None:
        generator = seed
    else:
        generator = random(_emit_ast=False)  # pragma: no cover

    zero_mean = lit(0, _emit_ast=False)
    unit_stddev = lit(1, _emit_ast=False)

    # normal() must not emit its own AST; it carries the randn AST built above.
    return normal(
        zero_mean,
        unit_stddev,
        generator,
        _emit_ast=False,
        _ast=randn_ast,
    )

tests/ast/data/functions2.test

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ df315 = df.select(nth_value("A", 2), nth_value("A", 2, True), nth_value(col("B")
324324

325325
df316 = df.select(bitshiftright_unsigned("A", 2), bitshiftright_unsigned("A", col("B")))
326326

327+
df317 = df.select(normal(1, 2, "A"))
328+
329+
df318 = df.select(randn(1))
330+
327331
## EXPECTED UNPARSER OUTPUT
328332

329333
df = session.table("table1")
@@ -648,6 +652,10 @@ df315 = df.select(nth_value("A", 2, False), nth_value("A", 2, True), nth_value(c
648652

649653
df316 = df.select(bitshiftright_unsigned("A", 2), bitshiftright_unsigned("A", col("B")))
650654

655+
df317 = df.select(normal(1, 2, "A"))
656+
657+
df318 = df.select(randn(1))
658+
651659
## EXPECTED ENCODED AST
652660

653661
interned_value_table {
@@ -26419,6 +26427,158 @@ body {
2641926427
}
2642026428
}
2642126429
}
26430+
body {
26431+
assign {
26432+
expr {
26433+
sp_dataframe_select__columns {
26434+
cols {
26435+
apply_expr {
26436+
fn {
26437+
builtin_fn {
26438+
name {
26439+
name {
26440+
sp_name_flat {
26441+
name: "normal"
26442+
}
26443+
}
26444+
}
26445+
}
26446+
}
26447+
pos_args {
26448+
int64_val {
26449+
src {
26450+
end_column: 43
26451+
end_line: 349
26452+
file: 2
26453+
start_column: 26
26454+
start_line: 349
26455+
}
26456+
v: 1
26457+
}
26458+
}
26459+
pos_args {
26460+
int64_val {
26461+
src {
26462+
end_column: 43
26463+
end_line: 349
26464+
file: 2
26465+
start_column: 26
26466+
start_line: 349
26467+
}
26468+
v: 2
26469+
}
26470+
}
26471+
pos_args {
26472+
string_val {
26473+
src {
26474+
end_column: 43
26475+
end_line: 349
26476+
file: 2
26477+
start_column: 26
26478+
start_line: 349
26479+
}
26480+
v: "A"
26481+
}
26482+
}
26483+
src {
26484+
end_column: 43
26485+
end_line: 349
26486+
file: 2
26487+
start_column: 26
26488+
start_line: 349
26489+
}
26490+
}
26491+
}
26492+
df {
26493+
sp_dataframe_ref {
26494+
id {
26495+
bitfield1: 1
26496+
}
26497+
}
26498+
}
26499+
src {
26500+
end_column: 44
26501+
end_line: 349
26502+
file: 2
26503+
start_column: 16
26504+
start_line: 349
26505+
}
26506+
variadic: true
26507+
}
26508+
}
26509+
symbol {
26510+
value: "df317"
26511+
}
26512+
uid: 162
26513+
var_id {
26514+
bitfield1: 162
26515+
}
26516+
}
26517+
}
26518+
body {
26519+
assign {
26520+
expr {
26521+
sp_dataframe_select__columns {
26522+
cols {
26523+
apply_expr {
26524+
fn {
26525+
builtin_fn {
26526+
name {
26527+
name {
26528+
sp_name_flat {
26529+
name: "randn"
26530+
}
26531+
}
26532+
}
26533+
}
26534+
}
26535+
pos_args {
26536+
int64_val {
26537+
src {
26538+
end_column: 34
26539+
end_line: 351
26540+
file: 2
26541+
start_column: 26
26542+
start_line: 351
26543+
}
26544+
v: 1
26545+
}
26546+
}
26547+
src {
26548+
end_column: 34
26549+
end_line: 351
26550+
file: 2
26551+
start_column: 26
26552+
start_line: 351
26553+
}
26554+
}
26555+
}
26556+
df {
26557+
sp_dataframe_ref {
26558+
id {
26559+
bitfield1: 1
26560+
}
26561+
}
26562+
}
26563+
src {
26564+
end_column: 35
26565+
end_line: 351
26566+
file: 2
26567+
start_column: 16
26568+
start_line: 351
26569+
}
26570+
variadic: true
26571+
}
26572+
}
26573+
symbol {
26574+
value: "df318"
26575+
}
26576+
uid: 163
26577+
var_id {
26578+
bitfield1: 163
26579+
}
26580+
}
26581+
}
2642226582
client_ast_version: 1
2642326583
client_language {
2642426584
python_language {

0 commit comments

Comments
 (0)