Skip to content

Commit a78670a

Browse files
committed
fix tests
1 parent a5bbfbe commit a78670a

File tree

5 files changed

+207
-41
lines changed

5 files changed

+207
-41
lines changed

src/snowflake/snowpark/_internal/analyzer/select_statement.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,7 @@ def select(self, cols: List[Expression]) -> "SelectStatement":
14801480
return new
14811481

14821482
def filter(self, col: Expression) -> "SelectStatement":
1483+
self._session._retrieve_aggregation_function_list()
14831484
can_be_flattened = (
14841485
(not self.flatten_disabled)
14851486
and can_clause_dependent_columns_flatten(
@@ -1527,6 +1528,9 @@ def sort(self, cols: List[Expression]) -> "SelectStatement":
15271528
derive_dependent_columns(*cols), self.column_states, "sort"
15281529
)
15291530
and not has_data_generator_exp(self.projection)
1531+
# unlike filter, we do not check for aggregation functions here:
1532+
# when an aggregation function is in the projection,
1533+
# order by is evaluated after aggregation, so row info is not taken into the calculation
15301534
)
15311535
if can_be_flattened:
15321536
new = copy(self)

src/snowflake/snowpark/context.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@
3131

3232
# This is an internal-only global flag, used to determine whether the api code which will be executed is compatible with snowflake.snowpark_connect
3333
_is_snowpark_connect_compatible_mode = False
34-
_aggregation_function_set = set()
34+
_aggregation_function_set = (
35+
set()
36+
) # lower cased names of aggregation functions, used in sql simplification
37+
_aggregation_function_set_lock = threading.RLock()
3538

3639
# Following are internal-only global flags, used to enable development features.
3740
_enable_dataframe_trace_on_error = False

src/snowflake/snowpark/session.py

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -521,20 +521,6 @@ def create(self) -> "Session":
521521
_add_session(session)
522522
else:
523523
session = self._create_internal(self._options.get("connection"))
524-
if context._is_snowpark_connect_compatible_mode:
525-
for sql in [
526-
"""select function_name from information_schema.functions where is_aggregate = 'YES'""",
527-
"""show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
528-
]:
529-
try:
530-
context._aggregation_function_set.update(
531-
{r[0] for r in session.sql(sql).collect()}
532-
)
533-
except BaseException as e:
534-
_logger.debug(
535-
"Unable to get aggregation functions from the database: %s",
536-
e,
537-
)
538524

539525
if self._app_name:
540526
if self._format_json:
@@ -4939,6 +4925,31 @@ def _execute_sproc_internal(
49394925
# Note the collect is implicit within the stored procedure call, so should not emit_ast here.
49404926
return df.collect(statement_params=statement_params, _emit_ast=False)[0][0]
49414927

4928+
def _retrieve_aggregation_function_list(self) -> None:
    """Retrieve the list of aggregation functions which will later be used in sql simplifier.

    Does nothing unless ``context._is_snowpark_connect_compatible_mode`` is
    set, and does nothing when ``context._aggregation_function_set`` already
    contains entries. Otherwise both lookup queries are executed through
    ``self.sql(...).collect()``, the first column of every returned row is
    lower-cased, and the collected names are merged into
    ``context._aggregation_function_set`` while holding
    ``context._aggregation_function_set_lock``.

    The lookup is best-effort: a failing query is logged at debug level and
    skipped. If nothing could be retrieved the shared set stays empty, so the
    next call attempts the lookup again.
    """
    if (
        not context._is_snowpark_connect_compatible_mode
        or context._aggregation_function_set
    ):
        return

    retrieved_set = set()

    for sql in [
        """select function_name from information_schema.functions where is_aggregate = 'YES'""",
        """show functions ->> select "name" from $1 where "is_aggregate" = 'Y'""",
    ]:
        try:
            retrieved_set.update({r[0].lower() for r in self.sql(sql).collect()})
        except Exception as e:
            # Only Exception is caught (not BaseException) so that
            # KeyboardInterrupt / SystemExit still propagate to the caller
            # instead of being silently swallowed by this best-effort lookup.
            _logger.debug(
                "Unable to get aggregation functions from the database: %s",
                e,
            )

    # The queries above run outside the lock; only the merge into the shared
    # module-level set is serialized.
    with context._aggregation_function_set_lock:
        context._aggregation_function_set.update(retrieved_set)
4952+
49424953
def directory(self, stage_name: str, _emit_ast: bool = True) -> DataFrame:
49434954
"""
49444955
Returns a DataFrame representing the results of a directory table query on the specified stage.

tests/integ/test_query_line_intervals.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def generate_test_data(session, sql_simplifier_enabled):
5757
}
5858

5959

60+
@pytest.mark.parametrize("snowpark_connect_compatible_mode", [True, False])
6061
@pytest.mark.parametrize(
61-
"op,sql_simplifier,line_to_expected_sql",
62+
"op,sql_simplifier,line_to_expected_sql,snowpark_connect_compatible_mode_sql",
6263
[
6364
(
6465
lambda data: data["df1"].union(data["df2"]),
@@ -68,10 +69,14 @@ def generate_test_data(session, sql_simplifier_enabled):
6869
6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
6970
10: 'SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (3 :: INT, \'C\' :: STRING, 300 :: INT), (4 :: INT, \'D\' :: STRING, 400 :: INT) )',
7071
},
72+
None,
7173
),
7274
(
7375
lambda data: data["df1"].filter(data["df1"].value > 150),
7476
True,
77+
{
78+
8: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)'
79+
},
7580
{
7681
8: """SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM (SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, 'A' :: STRING, 100 :: INT), (2 :: INT, 'B' :: STRING, 200 :: INT)) WHERE ("VALUE" > 150)""",
7782
},
@@ -83,6 +88,7 @@ def generate_test_data(session, sql_simplifier_enabled):
8388
1: 'SELECT "_1" AS "ID", "_2" AS "NAME" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) )',
8489
4: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
8590
},
91+
None,
8692
),
8793
(
8894
lambda data: data["df1"].pivot(F.col("name")).sum(F.col("value")),
@@ -92,12 +98,26 @@ def generate_test_data(session, sql_simplifier_enabled):
9298
6: 'SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT)',
9399
9: 'SELECT * FROM ( SELECT "_1" AS "ID", "_2" AS "NAME", "_3" AS "VALUE" FROM ( SELECT $1 AS "_1", $2 AS "_2", $3 AS "_3" FROM VALUES (1 :: INT, \'A\' :: STRING, 100 :: INT), (2 :: INT, \'B\' :: STRING, 200 :: INT) ) ) PIVOT ( sum("VALUE") FOR "NAME" IN ( ANY ) )',
94100
},
101+
None,
95102
),
96103
],
97104
)
98105
def test_get_plan_from_line_numbers_sql_content(
99-
session, op, sql_simplifier, line_to_expected_sql
106+
session,
107+
op,
108+
sql_simplifier,
109+
line_to_expected_sql,
110+
snowpark_connect_compatible_mode_sql,
111+
snowpark_connect_compatible_mode,
112+
monkeypatch,
100113
):
114+
if snowpark_connect_compatible_mode:
115+
import snowflake.snowpark.context as context
116+
117+
monkeypatch.setattr(context, "_is_snowpark_connect_compatible_mode", True)
118+
line_to_expected_sql = (
119+
snowpark_connect_compatible_mode_sql or line_to_expected_sql
120+
)
101121
session.sql_simplifier_enabled = sql_simplifier
102122
df = op(generate_test_data(session, sql_simplifier))
103123

0 commit comments

Comments
 (0)