Skip to content

Commit 5899e71

Browse files
NO SNOW: using mock instead of altering value in snowpark.context (#4025)
1 parent af1eef4 commit 5899e71

File tree

2 files changed

+105
-117
lines changed

2 files changed

+105
-117
lines changed

tests/integ/test_df_aggregate.py

Lines changed: 16 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#
55
import decimal
66
import math
7+
from unittest import mock
78

89
import pytest
910

@@ -31,7 +32,6 @@
3132
)
3233
from snowflake.snowpark.mock._snowflake_data_type import ColumnEmulator, ColumnType
3334
from snowflake.snowpark.types import DoubleType, IntegerType, StructType, StructField
34-
import snowflake.snowpark.context as context
3535
from tests.utils import Utils
3636

3737

@@ -644,10 +644,9 @@ def test_agg_on_empty_df(session):
644644
reason="HAVING clause is not supported in local testing mode",
645645
)
646646
def test_agg_filter_snowpark_connect_compatible(session):
647-
original_value = context._is_snowpark_connect_compatible_mode
648-
649-
try:
650-
context._is_snowpark_connect_compatible_mode = True
647+
with mock.patch(
648+
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
649+
):
651650
df = session.create_dataframe(
652651
[(1, 2, 3), (3, 2, 1), (3, 2, 1)], ["a", "b", "c"]
653652
)
@@ -679,19 +678,16 @@ def test_agg_filter_snowpark_connect_compatible(session):
679678
df.filter(grouping("a") == 0).collect()
680679

681680
Utils.check_answer(df.filter(col("a") > 1), [Row(3, 2, 1), Row(3, 2, 1)])
682-
finally:
683-
context._is_snowpark_connect_compatible_mode = original_value
684681

685682

686683
@pytest.mark.skipif(
687684
"config.getoption('local_testing_mode', default=False)",
688685
reason="ORDER BY append is not supported in local testing mode",
689686
)
690687
def test_agg_sort_snowpark_connect_compatible(session):
691-
original_value = context._is_snowpark_connect_compatible_mode
692-
693-
try:
694-
context._is_snowpark_connect_compatible_mode = True
688+
with mock.patch(
689+
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
690+
):
695691
df = session.create_dataframe(
696692
[(1, 2, 3), (3, 2, 1), (3, 2, 1)], ["a", "b", "c"]
697693
)
@@ -717,32 +713,27 @@ def test_agg_sort_snowpark_connect_compatible(session):
717713
# original behavior on dataframe without group by
718714
df4 = df.sort(col("a"))
719715
Utils.check_answer(df4, [Row(1, 2, 3), Row(3, 2, 1), Row(3, 2, 1)])
720-
finally:
721-
context._is_snowpark_connect_compatible_mode = original_value
722716

723717

724718
def test_agg_no_grouping_exprs_limit_snowpark_connect_compatible(session):
725-
original_value = context._is_snowpark_connect_compatible_mode
726-
try:
727-
context._is_snowpark_connect_compatible_mode = True
719+
with mock.patch(
720+
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
721+
):
728722
df = session.create_dataframe([[1, 2], [3, 4], [1, 4]], schema=["A", "B"])
729723
result = df.agg(sum_(col("a"))).limit(2)
730724
Utils.check_answer(result, [Row(5)])
731725
result = df.group_by().agg(sum_(col("b"))).limit(2)
732726
Utils.check_answer(result, [Row(10)])
733-
finally:
734-
context._is_snowpark_connect_compatible_mode = original_value
735727

736728

737729
@pytest.mark.skipif(
738730
"config.getoption('local_testing_mode', default=False)",
739731
reason="HAVING and ORDER BY append are not supported in local testing mode",
740732
)
741733
def test_agg_filter_and_sort_with_grouping_snowpark_connect_compatible(session):
742-
original_value = context._is_snowpark_connect_compatible_mode
743-
744-
try:
745-
context._is_snowpark_connect_compatible_mode = True
734+
with mock.patch(
735+
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
736+
):
746737
df = session.create_dataframe(
747738
[
748739
("dotNET", 2012, 10000),
@@ -852,19 +843,16 @@ def test_agg_filter_and_sort_with_grouping_snowpark_connect_compatible(session):
852843
# First row should have highest grouping value (1)
853844
results6 = df6.collect()
854845
assert results6[0][2] == 1 # gc=1 for NULL course
855-
finally:
856-
context._is_snowpark_connect_compatible_mode = original_value
857846

858847

859848
@pytest.mark.skipif(
860849
"config.getoption('local_testing_mode', default=False)",
861850
reason="HAVING, ORDER BY append, and limit append are not supported in local testing mode",
862851
)
863852
def test_filter_sort_limit_snowpark_connect_compatible(session, sql_simplifier_enabled):
864-
original_value = context._is_snowpark_connect_compatible_mode
865-
866-
try:
867-
context._is_snowpark_connect_compatible_mode = True
853+
with mock.patch(
854+
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
855+
):
868856
df = session.create_dataframe(
869857
[(1, 2, 3), (3, 2, 1), (3, 2, 1)], ["a", "b", "c"]
870858
)
@@ -939,9 +927,6 @@ def test_filter_sort_limit_snowpark_connect_compatible(session, sql_simplifier_e
939927
# Should have 4 SELECT statements
940928
assert query6.upper().count("SELECT") == 4 if sql_simplifier_enabled else 5
941929

942-
finally:
943-
context._is_snowpark_connect_compatible_mode = original_value
944-
945930

946931
@pytest.mark.skipif(
947932
"config.getoption('local_testing_mode', default=False)",

tests/integ/test_udtf.py

Lines changed: 89 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -8,10 +8,11 @@
88
import sys
99
from textwrap import dedent
1010
from typing import Dict, List, Tuple
11+
from unittest import mock
1112

1213
import pytest
1314

14-
from snowflake.snowpark import Row, Table, context
15+
from snowflake.snowpark import Row, Table
1516
from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
1617
from snowflake.snowpark._internal.utils import TempObjectType
1718
from snowflake.snowpark.exceptions import SnowparkSQLException
@@ -557,100 +558,102 @@ def __init__(self) -> None:
557558
df._column_map.columns = [Column("id"), Column("v")]
558559
return df
559560

560-
context._is_snowpark_connect_compatible_mode = True
561-
562-
df = create_snowpark_compatible_dataframe()
561+
with mock.patch(
562+
"snowflake.snowpark.context._is_snowpark_connect_compatible_mode", True
563+
):
563564

564-
def normalize(pdf):
565-
v = pdf.v
566-
return pdf.assign(v=(v - v.mean()) / v.std())
567-
568-
df = (
569-
df.group_by("id")
570-
.applyInPandas(
571-
normalize,
572-
output_schema=StructType(
573-
[
574-
StructField("id", IntegerType()),
575-
StructField("v", DoubleType()),
576-
]
577-
),
565+
df = create_snowpark_compatible_dataframe()
566+
567+
def normalize(pdf):
568+
v = pdf.v
569+
return pdf.assign(v=(v - v.mean()) / v.std())
570+
571+
df = (
572+
df.group_by("id")
573+
.applyInPandas(
574+
normalize,
575+
output_schema=StructType(
576+
[
577+
StructField("id", IntegerType()),
578+
StructField("v", DoubleType()),
579+
]
580+
),
581+
)
582+
.orderBy(["id", "v"])
578583
)
579-
.orderBy(["id", "v"])
580-
)
581-
582-
Utils.check_answer(
583-
df,
584-
[
585-
Row(ID=1, V=-0.7071067811865475),
586-
Row(ID=1, V=0.7071067811865475),
587-
Row(ID=2, V=-0.8320502943378437),
588-
Row(ID=2, V=-0.2773500981126146),
589-
Row(ID=2, V=1.1094003924504583),
590-
],
591-
)
592584

593-
df = create_snowpark_compatible_dataframe()
594-
595-
def sum_func(key, pdf):
596-
# key is a tuple of two numpy.int64s, which is the values
597-
# of 'id' and 'ceil(df.v / 2)' for the current group
598-
return pd.DataFrame([key + (pdf.v.sum(),)])
599-
600-
df = (
601-
df.group_by("id", ceil(df.v / 2).alias("newcol"))
602-
.applyInPandas(
603-
sum_func,
604-
output_schema=StructType(
605-
[
606-
StructField("id", IntegerType()),
607-
StructField("c", IntegerType()),
608-
StructField("v", DoubleType()),
609-
]
610-
),
585+
Utils.check_answer(
586+
df,
587+
[
588+
Row(ID=1, V=-0.7071067811865475),
589+
Row(ID=1, V=0.7071067811865475),
590+
Row(ID=2, V=-0.8320502943378437),
591+
Row(ID=2, V=-0.2773500981126146),
592+
Row(ID=2, V=1.1094003924504583),
593+
],
611594
)
612-
.orderBy(["id", "v"])
613-
)
614-
615-
Utils.check_answer(
616-
df,
617-
[
618-
Row(ID=1, C=1, V=3.0),
619-
Row(ID=2, C=2, V=3.0),
620-
Row(ID=2, C=3, V=5.0),
621-
Row(ID=2, C=5, V=10.0),
622-
],
623-
)
624595

625-
df = create_snowpark_compatible_dataframe()
596+
df = create_snowpark_compatible_dataframe()
597+
598+
def sum_func(key, pdf):
599+
# key is a tuple of two numpy.int64s, which is the values
600+
# of 'id' and 'ceil(df.v / 2)' for the current group
601+
return pd.DataFrame([key + (pdf.v.sum(),)])
602+
603+
df = (
604+
df.group_by("id", ceil(df.v / 2).alias("newcol"))
605+
.applyInPandas(
606+
sum_func,
607+
output_schema=StructType(
608+
[
609+
StructField("id", IntegerType()),
610+
StructField("c", IntegerType()),
611+
StructField("v", DoubleType()),
612+
]
613+
),
614+
)
615+
.orderBy(["id", "v"])
616+
)
626617

627-
def sum_func_with_single_input(pdf):
628-
# key is a tuple of two numpy.int64s, which is the values
629-
# of 'id' and 'ceil(df.v / 2)' for the current group
630-
return pd.DataFrame([(pdf.v.sum(),)])
618+
Utils.check_answer(
619+
df,
620+
[
621+
Row(ID=1, C=1, V=3.0),
622+
Row(ID=2, C=2, V=3.0),
623+
Row(ID=2, C=3, V=5.0),
624+
Row(ID=2, C=5, V=10.0),
625+
],
626+
)
631627

632-
df = (
633-
df.group_by("id", ceil(df.v / 2))
634-
.applyInPandas(
635-
sum_func_with_single_input,
636-
output_schema=StructType(
637-
[
638-
StructField("v", DoubleType()),
639-
]
640-
),
628+
df = create_snowpark_compatible_dataframe()
629+
630+
def sum_func_with_single_input(pdf):
631+
# key is a tuple of two numpy.int64s, which is the values
632+
# of 'id' and 'ceil(df.v / 2)' for the current group
633+
return pd.DataFrame([(pdf.v.sum(),)])
634+
635+
df = (
636+
df.group_by("id", ceil(df.v / 2))
637+
.applyInPandas(
638+
sum_func_with_single_input,
639+
output_schema=StructType(
640+
[
641+
StructField("v", DoubleType()),
642+
]
643+
),
644+
)
645+
.orderBy(["v"])
641646
)
642-
.orderBy(["v"])
643-
)
644647

645-
Utils.check_answer(
646-
df,
647-
[
648-
Row(V=3.0),
649-
Row(V=3.0),
650-
Row(V=5.0),
651-
Row(V=10.0),
652-
],
653-
)
648+
Utils.check_answer(
649+
df,
650+
[
651+
Row(V=3.0),
652+
Row(V=3.0),
653+
Row(V=5.0),
654+
Row(V=10.0),
655+
],
656+
)
654657

655658

656659
@pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")

0 commit comments

Comments (0)