Skip to content

Commit 23285ef

Browse files
authored
NO-SNOW: fix flaky test failures (#4010)
1 parent c5f1816 commit 23285ef

File tree

4 files changed

+13
-11
lines changed

4 files changed

+13
-11
lines changed

src/snowflake/snowpark/dataframe_analytics_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -496,15 +496,15 @@ def compute_lag(
496496
... lags=[1, 2],
497497
... order_by=["ORDERDATE"],
498498
... group_by=["PRODUCTKEY"],
499-
... )
499+
... ).sort("ORDERDATE")
500500
>>> res.show()
501501
------------------------------------------------------------------------------------------
502502
|"ORDERDATE" |"PRODUCTKEY" |"SALESAMOUNT" |"SALESAMOUNT_LAG_1" |"SALESAMOUNT_LAG_2" |
503503
------------------------------------------------------------------------------------------
504-
|2023-01-04 |102 |250 |NULL |NULL |
505504
|2023-01-01 |101 |200 |NULL |NULL |
506505
|2023-01-02 |101 |100 |200 |NULL |
507506
|2023-01-03 |101 |300 |100 |200 |
507+
|2023-01-04 |102 |250 |NULL |NULL |
508508
------------------------------------------------------------------------------------------
509509
<BLANKLINE>
510510
"""

src/snowflake/snowpark/functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8895,15 +8895,15 @@ def rank(_emit_ast: bool = True) -> Column:
88958895
... ],
88968896
... schema=["x", "y", "z"]
88978897
... )
8898-
>>> df.select(rank().over(Window.partition_by(col("X")).order_by(col("Y"))).alias("result")).show()
8898+
>>> df.select(rank().over(Window.partition_by(col("X")).order_by(col("Y"))).alias("result")).sort("result").show()
88998899
------------
89008900
|"RESULT" |
89018901
------------
89028902
|1 |
8903-
|2 |
8904-
|2 |
89058903
|1 |
89068904
|1 |
8905+
|2 |
8906+
|2 |
89078907
------------
89088908
<BLANKLINE>
89098909
"""
@@ -9020,8 +9020,8 @@ def lag(
90209020
... ],
90219021
... schema=["x", "y", "z"]
90229022
... )
9023-
>>> df.select(lag("Z").over(Window.partition_by(col("X")).order_by(col("Y"))).alias("result")).collect()
9024-
[Row(RESULT=None), Row(RESULT=10), Row(RESULT=1), Row(RESULT=None), Row(RESULT=1)]
9023+
>>> df.select(lag("Z").over(Window.partition_by(col("X")).order_by(col("Y"))).alias("result")).sort("result").collect()
9024+
[Row(RESULT=None), Row(RESULT=None), Row(RESULT=1), Row(RESULT=1), Row(RESULT=10)]
90259025
"""
90269026
# AST.
90279027
ast = (
@@ -11732,7 +11732,7 @@ def regr_avgx(y: ColumnOrName, x: ColumnOrName, _emit_ast: bool = True) -> Colum
1173211732
Example::
1173311733

1173411734
>>> df = session.create_dataframe([[10, 11], [20, 22], [25, None], [30, 35]], schema=["v", "v2"])
11735-
>>> df.groupBy("v").agg(regr_avgx(df["v"], df["v2"]).alias("regr_avgx")).collect()
11735+
>>> df.groupBy("v").agg(regr_avgx(df["v"], df["v2"]).alias("regr_avgx")).sort("v").collect()
1173611736
[Row(V=10, REGR_AVGX=11.0), Row(V=20, REGR_AVGX=22.0), Row(V=25, REGR_AVGX=None), Row(V=30, REGR_AVGX=35.0)]
1173711737
"""
1173811738
c1 = _to_col_if_str(y, "regr_avgx")
@@ -11766,7 +11766,7 @@ def regr_count(y: ColumnOrName, x: ColumnOrName, _emit_ast: bool = True) -> Colu
1176611766
Example::
1176711767

1176811768
>>> df = session.create_dataframe([[1, 10, 11], [1, 20, 22], [1, 25, None], [2, 30, 35]], schema=["k", "v", "v2"])
11769-
>>> df.group_by("k").agg(regr_count(col("v"), col("v2")).alias("regr_count")).collect()
11769+
>>> df.group_by("k").agg(regr_count(col("v"), col("v2")).alias("regr_count")).sort("k").collect()
1177011770
[Row(K=1, REGR_COUNT=2), Row(K=2, REGR_COUNT=1)]
1177111771
"""
1177211772
c1 = _to_col_if_str(y, "regr_count")

src/snowflake/snowpark/session.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1794,6 +1794,7 @@ def replicate_local_environment(
17941794
17951795
Example::
17961796
1797+
>>> import sys, pytest; _ = (sys.version_info[:2] != (3, 9)) or pytest.skip()
17971798
>>> from snowflake.snowpark.functions import udf
17981799
>>> import numpy
17991800
>>> import pandas
@@ -2714,6 +2715,7 @@ def table_function(
27142715
Example 1
27152716
Query a table function by function name:
27162717
2718+
>>> import sys, pytest; _ = (sys.version_info[:2] != (3, 9)) or pytest.skip()
27172719
>>> from snowflake.snowpark.functions import lit
27182720
>>> session.table_function("split_to_table", lit("split words to table"), lit(" ")).collect()
27192721
[Row(SEQ=1, INDEX=1, VALUE='split'), Row(SEQ=1, INDEX=2, VALUE='words'), Row(SEQ=1, INDEX=3, VALUE='to'), Row(SEQ=1, INDEX=4, VALUE='table')]

tests/integ/scala/test_dataframe_aggregate_suite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,14 +556,14 @@ def test_group_by(session):
556556
.with_column("medical_license", lit(None))
557557
.select("medical_license", "radio_license", "count")
558558
)
559-
.sort(col("count"))
559+
.sort(col("count"), col("radio_license"))
560560
.collect()
561561
)
562562
Utils.check_answer(
563563
result,
564564
[
565-
Row(None, "General", 1),
566565
Row(None, "Amateur Extra", 1),
566+
Row(None, "General", 1),
567567
Row("RN", None, 2),
568568
Row(None, "Technician", 2),
569569
Row(None, None, 3),

0 commit comments

Comments
 (0)