Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 114 additions & 16 deletions src/snowflake/snowpark/_internal/analyzer/analyzer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@
ICEBERG_VERSION = "ICEBERG_VERSION"
RENAME_FIELDS = " RENAME FIELDS"
ADD_FIELDS = " ADD FIELDS"
# Whitespace tokens used to pretty-print the generated SQL text.
NEW_LINE = "\n"
# NOTE(review): TAB renders as a single space here — likely collapsed from a
# multi-space indent by the paste; confirm the original width.
TAB = " "

# Keyword spellings that mark a table as session-scoped ("temporary").
TEMPORARY_STRING_SET = frozenset(["temporary", "temp"])

Expand Down Expand Up @@ -266,7 +268,9 @@ def partition_spec(col_exprs: List[str]) -> str:


def order_by_spec(col_exprs: List[str]) -> str:
    """Return an ORDER BY clause listing *col_exprs* one per indented line.

    Returns EMPTY_STRING when there are no order-by expressions, so callers
    can append the result unconditionally.
    """
    if not col_exprs:
        return EMPTY_STRING
    # One expression per indented line for readable generated SQL.
    return ORDER_BY + NEW_LINE + TAB + (COMMA + NEW_LINE + TAB).join(col_exprs)


def table_function_partition_spec(
Expand Down Expand Up @@ -297,7 +301,9 @@ def within_group_expression(column: str, order_by_cols: List[str]) -> str:
+ WITHIN_GROUP
+ LEFT_PARENTHESIS
+ ORDER_BY
+ COMMA.join(order_by_cols)
+ NEW_LINE
+ TAB
+ (COMMA + NEW_LINE + TAB).join(order_by_cols)
+ RIGHT_PARENTHESIS
)

def lateral_statement(lateral_expression: str, child: str) -> str:
    """Join the subquery *child* with a LATERAL table expression.

    Produces ``SELECT * FROM (<child>), LATERAL <lateral_expression>`` with
    the child subquery placed on its own lines for readable generated SQL.
    """
    return (
        SELECT
        + STAR
        + NEW_LINE
        + FROM
        + LEFT_PARENTHESIS
        + NEW_LINE
        + child
        + NEW_LINE
        + RIGHT_PARENTHESIS
        + COMMA
        + NEW_LINE
        + LATERAL
        + lateral_expression
    )
Expand All @@ -414,18 +424,25 @@ def join_table_function_statement(

left_cols = [f"{LEFT_ALIAS}.{col}" for col in left_cols]
right_cols = [f"{RIGHT_ALIAS}.{col}" for col in right_cols]
select_cols = COMMA.join(left_cols + right_cols)
select_cols = (COMMA + NEW_LINE + TAB).join(left_cols + right_cols)

return (
SELECT
+ NEW_LINE
+ TAB
+ select_cols
+ NEW_LINE
+ FROM
+ LEFT_PARENTHESIS
+ NEW_LINE
+ child
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ AS
+ LEFT_ALIAS
+ NEW_LINE
+ JOIN
+ NEW_LINE
+ table(func)
+ AS
+ RIGHT_ALIAS
Expand All @@ -451,19 +468,27 @@ def case_when_expression(branches: List[Tuple[str, str]], else_value: str) -> st


def project_statement(project: List[str], child: str, is_distinct: bool = False) -> str:
    """Build a SELECT statement over the subquery *child*.

    Args:
        project: column expressions to project; selects ``*`` when empty.
        child: SQL text of the subquery to select from.
        is_distinct: when True, emit ``SELECT DISTINCT``.
    """
    if not project:
        columns = STAR
    else:
        # One projected expression per indented line for readable SQL output.
        columns = NEW_LINE + TAB + (COMMA + NEW_LINE + TAB).join(project)

    return (
        SELECT
        + f"{DISTINCT if is_distinct else EMPTY_STRING}"
        + columns
        + NEW_LINE
        + FROM
        + LEFT_PARENTHESIS
        + NEW_LINE
        + child
        + NEW_LINE
        + RIGHT_PARENTHESIS
    )


def filter_statement(condition: str, child: str) -> str:
    """Append a WHERE clause with *condition* to a ``SELECT *`` over *child*."""
    return project_statement([], child) + NEW_LINE + WHERE + condition


def sample_statement(
Expand Down Expand Up @@ -500,7 +525,18 @@ def sample_by_statement(child: str, col: str, fractions: Dict[Any, float]) -> st
PERCENT_RANK_COL = random_name_for_temp_object(TempObjectType.COLUMN)
LEFT_ALIAS = "SNOWPARK_LEFT"
RIGHT_ALIAS = "SNOWPARK_RIGHT"
child_with_percentage_rank_stmt = f"SELECT *, PERCENT_RANK() OVER (PARTITION BY {col} ORDER BY RANDOM()) AS {PERCENT_RANK_COL} FROM ({child})"
child_with_percentage_rank_stmt = (
SELECT
+ STAR
+ COMMA
+ f"PERCENT_RANK() OVER (PARTITION BY {col} ORDER BY RANDOM()) AS {PERCENT_RANK_COL}"
+ FROM
+ LEFT_PARENTHESIS
+ NEW_LINE
+ child
+ NEW_LINE
+ RIGHT_PARENTHESIS
)

# PERCENT_RANK assigns values between 0.0 - 1.0 both inclusive. In our, query we only
# select values where percent_rank <= value. If value = 0, then we will select one sample
Expand All @@ -510,10 +546,28 @@ def sample_by_statement(child: str, col: str, fractions: Dict[Any, float]) -> st
fraction_flatten_stmt = f"SELECT KEY, VALUE FROM TABLE(FLATTEN(input => parse_json('{json.dumps(updated_fractions)}')))"

return (
f"{SELECT} {LEFT_ALIAS}.* EXCLUDE {PERCENT_RANK_COL} {FROM} ({child_with_percentage_rank_stmt}) {LEFT_ALIAS}"
f"{JOIN} ({fraction_flatten_stmt}) {RIGHT_ALIAS}"
f"{ON} {LEFT_ALIAS}.{col} = {RIGHT_ALIAS}.KEY"
f"{WHERE} {LEFT_ALIAS}.{PERCENT_RANK_COL} <= {RIGHT_ALIAS}.VALUE"
SELECT
+ f"{LEFT_ALIAS}.* EXCLUDE {PERCENT_RANK_COL}"
+ FROM
+ LEFT_PARENTHESIS
+ NEW_LINE
+ child_with_percentage_rank_stmt
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ AS
+ LEFT_ALIAS
+ JOIN
+ LEFT_PARENTHESIS
+ NEW_LINE
+ fraction_flatten_stmt
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ AS
+ RIGHT_ALIAS
+ ON
+ f"{LEFT_ALIAS}.{col} = {RIGHT_ALIAS}.KEY"
+ WHERE
+ f"{LEFT_ALIAS}.{PERCENT_RANK_COL} <= {RIGHT_ALIAS}.VALUE"
)


Expand All @@ -525,12 +579,25 @@ def aggregate_statement(
return project_statement(aggregate_exprs, child) + (
limit_expression(1)
if not grouping_exprs
else (GROUP_BY + COMMA.join(grouping_exprs))
else (
NEW_LINE
+ GROUP_BY
+ NEW_LINE
+ TAB
+ (COMMA + NEW_LINE + TAB).join(grouping_exprs)
)
)


def sort_statement(order: List[str], child: str) -> str:
    """Append an ORDER BY clause over *order* expressions to a ``SELECT *``
    of *child*, one expression per indented line."""
    return (
        project_statement([], child)
        + NEW_LINE
        + ORDER_BY
        + NEW_LINE
        + TAB
        + (COMMA + NEW_LINE + TAB).join(order)
    )


def range_statement(start: int, end: int, step: int, column_name: str) -> str:
Expand Down Expand Up @@ -773,13 +840,16 @@ def snowflake_supported_join_statement(
+ AS
+ left_alias
+ SPACE
+ NEW_LINE
+ join_sql
+ JOIN
+ NEW_LINE
+ LEFT_PARENTHESIS
+ right
+ RIGHT_PARENTHESIS
+ AS
+ right_alias
+ NEW_LINE
+ f"{match_condition if match_condition else EMPTY_STRING}"
+ f"{using_condition if using_condition else EMPTY_STRING}"
+ f"{join_condition if join_condition else EMPTY_STRING}"
Expand Down Expand Up @@ -1284,10 +1354,15 @@ def pivot_statement(
+ select_str
+ FROM
+ LEFT_PARENTHESIS
+ NEW_LINE
+ child
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ NEW_LINE
+ PIVOT
+ LEFT_PARENTHESIS
+ NEW_LINE
+ TAB
+ aggregate
+ FOR
+ pivot_column
Expand All @@ -1298,6 +1373,7 @@ def pivot_statement(
if default_on_null
else EMPTY_STRING
)
+ NEW_LINE
+ RIGHT_PARENTHESIS
)

Expand All @@ -1314,18 +1390,24 @@ def unpivot_statement(
+ STAR
+ FROM
+ LEFT_PARENTHESIS
+ NEW_LINE
+ child
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ NEW_LINE
+ UNPIVOT
+ (INCLUDE_NULLS if include_nulls else EMPTY_STRING)
+ LEFT_PARENTHESIS
+ NEW_LINE
+ TAB
+ value_column
+ FOR
+ name_column
+ IN
+ LEFT_PARENTHESIS
+ COMMA.join(column_list)
+ RIGHT_PARENTHESIS
+ NEW_LINE
+ RIGHT_PARENTHESIS
)

Expand All @@ -1336,11 +1418,16 @@ def rename_statement(column_map: Dict[str, str], child: str) -> str:
+ STAR
+ RENAME
+ LEFT_PARENTHESIS
+ NEW_LINE
+ TAB
+ COMMA.join([f"{before}{AS}{after}" for before, after in column_map.items()])
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ FROM
+ LEFT_PARENTHESIS
+ NEW_LINE
+ child
+ NEW_LINE
+ RIGHT_PARENTHESIS
)

Expand Down Expand Up @@ -1422,7 +1509,11 @@ def copy_into_table(
+ column_str
+ FROM
+ from_str
+ (PATTERN + EQUALS + single_quote(pattern) if pattern else EMPTY_STRING)
+ (
NEW_LINE + PATTERN + EQUALS + single_quote(pattern)
if pattern
else EMPTY_STRING
)
+ files_str
+ ftostr
+ costr
Expand Down Expand Up @@ -1590,10 +1681,15 @@ def merge_statement(
+ table_name
+ USING
+ LEFT_PARENTHESIS
+ NEW_LINE
+ TAB
+ source
+ NEW_LINE
+ RIGHT_PARENTHESIS
+ NEW_LINE
+ ON
+ join_expr
+ NEW_LINE
+ EMPTY_STRING.join(clauses)
)

Expand Down Expand Up @@ -1873,9 +1969,11 @@ def drop_object(name: str, object_type: str) -> None:
target_table_location = build_location_helper(
database,
schema,
random_name_for_temp_object(TempObjectType.TABLE)
if (overwrite and auto_create_table)
else table_name,
(
random_name_for_temp_object(TempObjectType.TABLE)
if (overwrite and auto_create_table)
else table_name
),
quote_identifiers,
)

Expand Down
Loading