Skip to content

Commit 782de21

Browse files
SNOW-1672579 Encode DataFrame.to_snowpark_pandas (#2711)
1. Which Jira issue is this PR addressing? Make sure that there is an accompanying issue to your PR.

   Fixes SNOW-1672579

2. Fill out the following pre-review checklist:
   - [x] I am adding a new automated test(s) to verify correctness of my new code
   - [ ] If this test skips Local Testing mode, I'm requesting review from @snowflakedb/local-testing
   - [ ] I am adding new logging messages
   - [ ] I am adding a new telemetry message
   - [ ] I am adding new credentials
   - [ ] I am adding a new dependency
   - [ ] If this is a new feature/behavior, I'm adding the Local Testing parity changes.
   - [x] I acknowledge that I have ensured my changes to be thread-safe. Follow the link for more information: [Thread-safe Developer Guidelines](https://github.com/snowflakedb/snowpark-python/blob/main/CONTRIBUTING.md#thread-safe-development)

3. Please describe how your code solves the related issue.
   - Added AST encoding for `DataFrame.to_snowpark_pandas`.
   - I had to add some Local Testing functionality to help my expectation test pass. I also needed to add some janky logic to create a temp read-only table in lieu of the table that is created by `to_snowpark_pandas`.
   - I updated the script generating the relevant proto files to create them in the correct directory.
1 parent 3a66c84 commit 782de21

File tree

8 files changed

+307
-61
lines changed

8 files changed

+307
-61
lines changed

src/snowflake/snowpark/_internal/proto/ast.proto

Lines changed: 58 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -793,21 +793,22 @@ message Expr {
793793
SpTableMerge sp_table_merge = 183;
794794
SpTableSample sp_table_sample = 184;
795795
SpTableUpdate sp_table_update = 185;
796-
SpWriteCopyIntoLocation sp_write_copy_into_location = 186;
797-
SpWriteCsv sp_write_csv = 187;
798-
SpWriteJson sp_write_json = 188;
799-
SpWritePandas sp_write_pandas = 189;
800-
SpWriteParquet sp_write_parquet = 190;
801-
SpWriteTable sp_write_table = 191;
802-
StoredProcedure stored_procedure = 192;
803-
StringVal string_val = 193;
804-
Sub sub = 194;
805-
TimeVal time_val = 195;
806-
TimestampVal timestamp_val = 196;
807-
TupleVal tuple_val = 197;
808-
Udaf udaf = 198;
809-
Udf udf = 199;
810-
Udtf udtf = 200;
796+
SpToSnowparkPandas sp_to_snowpark_pandas = 186;
797+
SpWriteCopyIntoLocation sp_write_copy_into_location = 187;
798+
SpWriteCsv sp_write_csv = 188;
799+
SpWriteJson sp_write_json = 189;
800+
SpWritePandas sp_write_pandas = 190;
801+
SpWriteParquet sp_write_parquet = 191;
802+
SpWriteTable sp_write_table = 192;
803+
StoredProcedure stored_procedure = 193;
804+
StringVal string_val = 194;
805+
Sub sub = 195;
806+
TimeVal time_val = 196;
807+
TimestampVal timestamp_val = 197;
808+
TupleVal tuple_val = 198;
809+
Udaf udaf = 199;
810+
Udf udf = 200;
811+
Udtf udtf = 201;
811812
}
812813
}
813814

@@ -1075,26 +1076,27 @@ message HasSrcPosition {
10751076
SpTableMerge sp_table_merge = 192;
10761077
SpTableSample sp_table_sample = 193;
10771078
SpTableUpdate sp_table_update = 194;
1078-
SpWindowSpecEmpty sp_window_spec_empty = 195;
1079-
SpWindowSpecOrderBy sp_window_spec_order_by = 196;
1080-
SpWindowSpecPartitionBy sp_window_spec_partition_by = 197;
1081-
SpWindowSpecRangeBetween sp_window_spec_range_between = 198;
1082-
SpWindowSpecRowsBetween sp_window_spec_rows_between = 199;
1083-
SpWriteCopyIntoLocation sp_write_copy_into_location = 200;
1084-
SpWriteCsv sp_write_csv = 201;
1085-
SpWriteJson sp_write_json = 202;
1086-
SpWritePandas sp_write_pandas = 203;
1087-
SpWriteParquet sp_write_parquet = 204;
1088-
SpWriteTable sp_write_table = 205;
1089-
StoredProcedure stored_procedure = 206;
1090-
StringVal string_val = 207;
1091-
Sub sub = 208;
1092-
TimeVal time_val = 209;
1093-
TimestampVal timestamp_val = 210;
1094-
TupleVal tuple_val = 211;
1095-
Udaf udaf = 212;
1096-
Udf udf = 213;
1097-
Udtf udtf = 214;
1079+
SpToSnowparkPandas sp_to_snowpark_pandas = 195;
1080+
SpWindowSpecEmpty sp_window_spec_empty = 196;
1081+
SpWindowSpecOrderBy sp_window_spec_order_by = 197;
1082+
SpWindowSpecPartitionBy sp_window_spec_partition_by = 198;
1083+
SpWindowSpecRangeBetween sp_window_spec_range_between = 199;
1084+
SpWindowSpecRowsBetween sp_window_spec_rows_between = 200;
1085+
SpWriteCopyIntoLocation sp_write_copy_into_location = 201;
1086+
SpWriteCsv sp_write_csv = 202;
1087+
SpWriteJson sp_write_json = 203;
1088+
SpWritePandas sp_write_pandas = 204;
1089+
SpWriteParquet sp_write_parquet = 205;
1090+
SpWriteTable sp_write_table = 206;
1091+
StoredProcedure stored_procedure = 207;
1092+
StringVal string_val = 208;
1093+
Sub sub = 209;
1094+
TimeVal time_val = 210;
1095+
TimestampVal timestamp_val = 211;
1096+
TupleVal tuple_val = 212;
1097+
Udaf udaf = 213;
1098+
Udf udf = 214;
1099+
Udtf udtf = 215;
10981100
}
10991101
}
11001102

@@ -1539,7 +1541,7 @@ message SpDataframeAlias {
15391541
SrcPosition src = 3;
15401542
}
15411543

1542-
// sp-df-expr.ir:464
1544+
// sp-df-expr.ir:470
15431545
message SpDataframeAnalyticsComputeLag {
15441546
repeated Expr cols = 1;
15451547
SpDataframeExpr df = 2;
@@ -1550,7 +1552,7 @@ message SpDataframeAnalyticsComputeLag {
15501552
SrcPosition src = 7;
15511553
}
15521554

1553-
// sp-df-expr.ir:473
1555+
// sp-df-expr.ir:479
15541556
message SpDataframeAnalyticsComputeLead {
15551557
repeated Expr cols = 1;
15561558
SpDataframeExpr df = 2;
@@ -1561,7 +1563,7 @@ message SpDataframeAnalyticsComputeLead {
15611563
SrcPosition src = 7;
15621564
}
15631565

1564-
// sp-df-expr.ir:455
1566+
// sp-df-expr.ir:461
15651567
message SpDataframeAnalyticsCumulativeAgg {
15661568
repeated Tuple_String_List_String aggs = 1;
15671569
SpDataframeExpr df = 2;
@@ -1572,7 +1574,7 @@ message SpDataframeAnalyticsCumulativeAgg {
15721574
SrcPosition src = 7;
15731575
}
15741576

1575-
// sp-df-expr.ir:446
1577+
// sp-df-expr.ir:452
15761578
message SpDataframeAnalyticsMovingAgg {
15771579
repeated Tuple_String_List_String aggs = 1;
15781580
SpDataframeExpr df = 2;
@@ -1583,7 +1585,7 @@ message SpDataframeAnalyticsMovingAgg {
15831585
repeated int64 window_sizes = 7;
15841586
}
15851587

1586-
// sp-df-expr.ir:482
1588+
// sp-df-expr.ir:488
15871589
message SpDataframeAnalyticsTimeSeriesAgg {
15881590
repeated Tuple_String_List_String aggs = 1;
15891591
SpDataframeExpr df = 2;
@@ -2330,21 +2332,21 @@ message SpMatchedClause {
23302332
}
23312333
}
23322334

2333-
// sp-df-expr.ir:499
2335+
// sp-df-expr.ir:505
23342336
message SpMergeDeleteWhenMatchedClause {
23352337
Expr condition = 1;
23362338
SrcPosition src = 2;
23372339
}
23382340

2339-
// sp-df-expr.ir:503
2341+
// sp-df-expr.ir:509
23402342
message SpMergeInsertWhenNotMatchedClause {
23412343
Expr condition = 1;
23422344
List_Expr insert_keys = 2;
23432345
List_Expr insert_values = 3;
23442346
SrcPosition src = 4;
23452347
}
23462348

2347-
// sp-df-expr.ir:494
2349+
// sp-df-expr.ir:500
23482350
message SpMergeUpdateWhenMatchedClause {
23492351
Expr condition = 1;
23502352
SrcPosition src = 2;
@@ -2490,7 +2492,7 @@ message SpTable {
24902492
SpTableVariant variant = 4;
24912493
}
24922494

2493-
// sp-df-expr.ir:509
2495+
// sp-df-expr.ir:515
24942496
message SpTableDelete {
24952497
bool block = 1;
24962498
Expr condition = 2;
@@ -2500,7 +2502,7 @@ message SpTableDelete {
25002502
repeated Tuple_String_String statement_params = 6;
25012503
}
25022504

2503-
// sp-df-expr.ir:517
2505+
// sp-df-expr.ir:523
25042506
message SpTableDropTable {
25052507
VarId id = 1;
25062508
SrcPosition src = 2;
@@ -2521,7 +2523,7 @@ message SpTableFnCallOver {
25212523
SrcPosition src = 4;
25222524
}
25232525

2524-
// sp-df-expr.ir:521
2526+
// sp-df-expr.ir:527
25252527
message SpTableMerge {
25262528
bool block = 1;
25272529
repeated SpMatchedClause clauses = 2;
@@ -2532,7 +2534,7 @@ message SpTableMerge {
25322534
repeated Tuple_String_String statement_params = 7;
25332535
}
25342536

2535-
// sp-df-expr.ir:530
2537+
// sp-df-expr.ir:536
25362538
message SpTableSample {
25372539
SpDataframeExpr df = 1;
25382540
google.protobuf.Int64Value num = 2;
@@ -2542,7 +2544,7 @@ message SpTableSample {
25422544
SrcPosition src = 6;
25432545
}
25442546

2545-
// sp-df-expr.ir:538
2547+
// sp-df-expr.ir:544
25462548
message SpTableUpdate {
25472549
repeated Tuple_String_Expr assignments = 1;
25482550
bool block = 2;
@@ -2553,6 +2555,14 @@ message SpTableUpdate {
25532555
repeated Tuple_String_String statement_params = 7;
25542556
}
25552557

2558+
// sp-df-expr.ir:438
2559+
message SpToSnowparkPandas {
2560+
List_String columns = 1;
2561+
SpDataframeExpr df = 2;
2562+
List_String index_col = 3;
2563+
SrcPosition src = 4;
2564+
}
2565+
25562566
message SpType {
25572567
oneof variant {
25582568
SpColExprType sp_col_expr_type = 1;

src/snowflake/snowpark/column.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1249,7 +1249,6 @@ def name(
12491249
expr = self._expression # Snowpark expression
12501250
if isinstance(expr, Alias):
12511251
expr = expr.child
1252-
12531252
ast_expr = None # Snowpark IR expression
12541253
if _emit_ast and self._ast is not None:
12551254
ast_expr = proto.Expr()

src/snowflake/snowpark/dataframe.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,22 +1241,47 @@ def to_snowpark_pandas(
12411241
# If snowflake.snowpark.modin.plugin was successfully imported, then modin.pandas is available
12421242
import modin.pandas as pd # isort: skip
12431243
# fmt: on
1244+
1245+
# AST.
1246+
stmt = None
12441247
if _emit_ast:
1245-
raise NotImplementedError(
1246-
"TODO SNOW-1672579: Support Snowpark pandas API handover."
1247-
)
1248+
stmt = self._session._ast_batch.assign()
1249+
ast = with_src_position(stmt.expr.sp_to_snowpark_pandas, stmt)
1250+
self._set_ast_ref(ast.df)
1251+
debug_check_missing_ast(self._ast_id, self)
1252+
if index_col is not None:
1253+
ast.index_col.list.extend(
1254+
index_col if isinstance(index_col, list) else [index_col]
1255+
)
1256+
if columns is not None:
1257+
ast.columns.list.extend(
1258+
columns if isinstance(columns, list) else [columns]
1259+
)
1260+
12481261
# create a temporary table out of the current snowpark dataframe
12491262
temporary_table_name = random_name_for_temp_object(
12501263
TempObjectType.TABLE
12511264
) # pragma: no cover
1265+
ast_id = self._ast_id
1266+
self._ast_id = None # set the AST ID to None to prevent AST emission.
12521267
self.write.save_as_table(
1253-
temporary_table_name, mode="errorifexists", table_type="temporary"
1268+
temporary_table_name,
1269+
mode="errorifexists",
1270+
table_type="temporary",
1271+
_emit_ast=False,
12541272
) # pragma: no cover
1273+
self._ast_id = ast_id # reset the AST ID.
12551274

12561275
snowpandas_df = pd.read_snowflake(
12571276
name_or_query=temporary_table_name, index_col=index_col, columns=columns
12581277
) # pragma: no cover
12591278

1279+
if _emit_ast:
1280+
# Set the Snowpark DataFrame AST ID to the AST ID of this pandas query.
1281+
snowpandas_df._query_compiler._modin_frame.ordered_dataframe._dataframe_ref.snowpark_dataframe._ast_id = (
1282+
stmt.var_id.bitfield1
1283+
)
1284+
12601285
return snowpandas_df
12611286

12621287
def __getitem__(self, item: Union[str, Column, List, Tuple, int]):
@@ -3904,7 +3929,7 @@ def count(
39043929
return result[0][0] if block else result
39053930

39063931
@property
3907-
def write(self, _emit_ast: bool = True) -> DataFrameWriter:
3932+
def write(self) -> DataFrameWriter:
39083933
"""Returns a new :class:`DataFrameWriter` object that you can use to write the data in the :class:`DataFrame` to
39093934
a Snowflake database or a stage location
39103935
@@ -3925,7 +3950,7 @@ def write(self, _emit_ast: bool = True) -> DataFrameWriter:
39253950
"""
39263951

39273952
# AST.
3928-
if _emit_ast and self._ast_id is not None:
3953+
if self._ast_id is not None:
39293954
stmt = self._session._ast_batch.assign()
39303955
expr = with_src_position(stmt.expr.sp_dataframe_write, stmt)
39313956
self._set_ast_ref(expr.df)

src/snowflake/snowpark/mock/_analyzer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -828,10 +828,10 @@ def do_resolve_with_resolved_children(
828828
)
829829

830830
if isinstance(logical_plan, Project):
831-
return logical_plan
831+
return MockExecutionPlan(logical_plan, self.session)
832832

833833
if isinstance(logical_plan, Filter):
834-
return logical_plan
834+
return MockExecutionPlan(logical_plan, self.session)
835835

836836
# Add a sample stop to the plan being built
837837
if isinstance(logical_plan, Sample):
@@ -895,6 +895,9 @@ def do_resolve_with_resolved_children(
895895
if isinstance(logical_plan, SnowflakeCreateTable):
896896
return MockExecutionPlan(logical_plan, self.session)
897897

898+
if isinstance(logical_plan, SnowflakePlan):
899+
return MockExecutionPlan(logical_plan, self.session)
900+
898901
if isinstance(logical_plan, Limit):
899902
on_top_of_order_by = isinstance(
900903
logical_plan.child, SnowflakePlan

src/snowflake/snowpark/mock/_plan.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@
135135
CreateViewCommand,
136136
Pivot,
137137
Sample,
138+
Project,
138139
)
139140
from snowflake.snowpark._internal.type_utils import infer_type
140141
from snowflake.snowpark._internal.utils import (
@@ -1289,6 +1290,8 @@ def aggregate_by_groups(cur_group: TableEmulator):
12891290
dtype=object,
12901291
)
12911292
return result_df
1293+
if isinstance(source_plan, Project):
1294+
return TableEmulator(ColumnEmulator(col) for col in source_plan.project_list)
12921295
if isinstance(source_plan, Join):
12931296
L_expr_to_alias = {}
12941297
R_expr_to_alias = {}
@@ -1450,6 +1453,19 @@ def outer_join(base_df):
14501453

14511454
obj_name_tuple = parse_table_name(entity_name)
14521455
obj_name = obj_name_tuple[-1]
1456+
1457+
# Logic to create a read-only temp table for AST testing purposes.
1458+
# Functions like to_snowpark_pandas create a clone of an existing table as a read-only table that is referenced
1459+
# during testing.
1460+
if "SNOWPARK_TEMP_TABLE" in obj_name and "READONLY" in obj_name:
1461+
# Create the read-only temp table.
1462+
entity_registry.write_table(
1463+
obj_name,
1464+
TableEmulator({"A": [1], "B": [1], "C": [1]}),
1465+
SaveMode.IGNORE,
1466+
)
1467+
return entity_registry.read_table_if_exists(obj_name)
1468+
14531469
obj_schema = (
14541470
obj_name_tuple[-2]
14551471
if len(obj_name_tuple) > 1

src/snowflake/snowpark/modin/plugin/_internal/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -313,8 +313,9 @@ def _create_read_only_table(
313313
)
314314
# TODO (SNOW-1669224): pushing read only table creation down to snowpark for general usage
315315
session.sql(
316-
f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} READ ONLY TABLE {readonly_table_name} CLONE {table_name}"
317-
).collect(statement_params=statement_params)
316+
f"CREATE OR REPLACE {get_temp_type_for_object(use_scoped_temp_objects=use_scoped_temp_table, is_generated=True)} READ ONLY TABLE {readonly_table_name} CLONE {table_name}",
317+
_emit_ast=False,
318+
).collect(statement_params=statement_params, _emit_ast=False)
318319

319320
return readonly_table_name
320321

@@ -389,7 +390,7 @@ def create_ordered_dataframe_with_readonly_temp_table(
389390
error_code=SnowparkPandasErrorCode.GENERAL_SQL_EXCEPTION.value,
390391
) from ex
391392
initial_ordered_dataframe = OrderedDataFrame(
392-
DataFrameReference(session.table(readonly_table_name))
393+
DataFrameReference(session.table(readonly_table_name, _emit_ast=False))
393394
)
394395
# generate a snowflake quoted identifier for row position column that can be used for aliasing
395396
snowflake_quoted_identifiers = (
@@ -415,7 +416,7 @@ def create_ordered_dataframe_with_readonly_temp_table(
415416
# with the created snowpark dataframe. In order to get the metadata column access in the created
416417
# dataframe, we create dataframe through sql which access the corresponding metadata column.
417418
dataframe_sql = f"SELECT {columns_to_select} FROM {readonly_table_name}"
418-
snowpark_df = session.sql(dataframe_sql)
419+
snowpark_df = session.sql(dataframe_sql, _emit_ast=False)
419420

420421
result_columns_quoted_identifiers = [
421422
row_position_snowflake_quoted_identifier

0 commit comments

Comments (0)