
Commit 0b08ec5

Add functionality in apply_in_pandas to support the Spark API (#3162)
1 parent 6c1ea54 commit 0b08ec5

File tree

3 files changed: +83 -2 lines changed

src/snowflake/snowpark/context.py
src/snowflake/snowpark/relational_grouped_dataframe.py
tests/integ/test_udtf.py

src/snowflake/snowpark/context.py

Lines changed: 3 additions & 0 deletions

@@ -26,6 +26,9 @@
 _use_structured_type_semantics = False
 _use_structured_type_semantics_lock = threading.RLock()
 
+# This is an internal-only global flag, used to determine whether the API code being executed is compatible with snowflake.snowpark_connect.
+_is_snowpark_connect_compatible_mode = False
+
 
 def _should_use_structured_type_semantics():
     global _use_structured_type_semantics
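
Because the flag is plain module-level state, an integration layer toggles it by assigning to the context module. A minimal sketch of that usage, assuming the caller (for example snowflake.snowpark_connect) is responsible for resetting it; the actual entry point that sets the flag is not part of this commit:

    import snowflake.snowpark.context as context

    # Enable the Spark-compatible apply_in_pandas code path (internal flag).
    context._is_snowpark_connect_compatible_mode = True
    try:
        ...  # run Spark-style applyInPandas calls here
    finally:
        # Restore the default so unrelated Snowpark code is unaffected.
        context._is_snowpark_connect_compatible_mode = False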

src/snowflake/snowpark/relational_grouped_dataframe.py

Lines changed: 31 additions & 1 deletion

@@ -3,9 +3,12 @@
 # Copyright (c) 2012-2025 Snowflake Computing Inc. All rights reserved.
 #
 from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union
+import inspect
 
 import snowflake.snowpark._internal.proto.generated.ast_pb2 as proto
+import snowflake.snowpark.context as context
 from snowflake.connector.options import pandas
+from snowflake.snowpark._internal.analyzer.analyzer_utils import unquote_if_quoted
 from snowflake.snowpark import functions
 from snowflake.snowpark._internal.analyzer.expression import (
     Expression,

@@ -404,8 +407,36 @@ def apply_in_pandas(
         - :func:`~snowflake.snowpark.functions.pandas_udtf`
         """
 
+        partition_by = [Column(expr, _emit_ast=False) for expr in self._grouping_exprs]
+
+        # This path handles the case where apply_in_pandas is called from Spark.
+        # It does not handle nested column access; the function is assumed not to access columns in a nested way.
+        original_columns: List[str] | None = None
+        key_columns: List[str] | None = None
+        if context._is_snowpark_connect_compatible_mode:
+            if self._dataframe._column_map is not None:
+                original_columns = [
+                    column.spark_name for column in self._dataframe._column_map.columns
+                ]
+            signature = inspect.signature(func)
+            parameters = signature.parameters
+            if len(parameters) == 2:
+                key_columns = [
+                    unquote_if_quoted(col.get_name()) for col in partition_by
+                ]
+
         class _ApplyInPandas:
             def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame:
+                if key_columns is not None:
+                    import numpy as np
+
+                    key_list = [pdf[key].iloc[0] for key in key_columns]
+                    numpy_array = np.array(key_list)
+                    keys = tuple(numpy_array)
+                if original_columns is not None:
+                    pdf.columns = original_columns
+                if key_columns is not None:
+                    return func(keys, pdf)
                 return func(pdf)
 
         # for vectorized UDTF

@@ -427,7 +458,6 @@ def end_partition(self, pdf: pandas.DataFrame) -> pandas.DataFrame:
             _emit_ast=_emit_ast,
             **kwargs,
         )
-        partition_by = [Column(expr, _emit_ast=False) for expr in self._grouping_exprs]
 
         df = self._dataframe.select(
             _apply_in_pandas_udtf(*self._dataframe.columns).over(
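
With this change, apply_in_pandas dispatches on the arity of func: a one-parameter function receives only the partition as a pandas DataFrame, while a two-parameter function also receives the group-by key values as a tuple (taken from the first row of the partition), mirroring PySpark's applyInPandas convention. A sketch of the two shapes a user function can take; normalize and shift_by_key are illustrative names, not part of this commit:

    import pandas

    # One-parameter form: invoked as func(pdf).
    def normalize(pdf: pandas.DataFrame) -> pandas.DataFrame:
        v = pdf.v
        return pdf.assign(v=(v - v.mean()) / v.std())

    # Two-parameter form: invoked as func(keys, pdf), where keys holds one
    # value per group-by column for this partition.
    def shift_by_key(keys, pdf: pandas.DataFrame) -> pandas.DataFrame:
        return pdf.assign(v=pdf.v - keys[0])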

tests/integ/test_udtf.py

Lines changed: 49 additions & 1 deletion

@@ -9,7 +9,7 @@
 
 import pytest
 
-from snowflake.snowpark import Row, Table
+from snowflake.snowpark import Row, Table, context
 from snowflake.snowpark._internal.utils import TempObjectType
 from snowflake.snowpark.exceptions import SnowparkSQLException
 from snowflake.snowpark.functions import lit, udtf

@@ -532,6 +532,54 @@ def group_sum(pdf):
         ],
     )
 
+    class Column:
+        def __init__(self, spark_name: str) -> None:
+            self.spark_name = spark_name
+
+    class ColumnMap:
+        def __init__(self) -> None:
+            self.columns: List[Column] = []
+
+    # test applyInPandas with the Spark-compatible mode enabled
+    df = session.createDataFrame(
+        [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v")
+    )
+
+    # mock the column map that snowpark_connect attaches to the DataFrame
+    df._column_map = ColumnMap()
+    df._column_map.columns = [Column("id"), Column("v")]
+
+    context._is_snowpark_connect_compatible_mode = True
+
+    def normalize(pdf):
+        v = pdf.v
+        return pdf.assign(v=(v - v.mean()) / v.std())
+
+    df = (
+        df.group_by("id")
+        .applyInPandas(
+            normalize,
+            output_schema=StructType(
+                [
+                    StructField("id", IntegerType()),
+                    StructField("v", DoubleType()),
+                ]
+            ),
+        )
+        .orderBy(["id", "v"])
+    )
+
+    Utils.check_answer(
+        df,
+        [
+            Row(ID=1, V=-0.7071067811865475),
+            Row(ID=1, V=0.7071067811865475),
+            Row(ID=2, V=-0.8320502943378437),
+            Row(ID=2, V=-0.2773500981126146),
+            Row(ID=2, V=1.1094003924504583),
+        ],
+    )
 
 
 @pytest.mark.skipif(IS_IN_STORED_PROC, reason="Cannot create session in SP")
 def test_permanent_udtf_negative(session, db_parameters):
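
The test above covers only the one-argument form of the user function. A hypothetical variant for the new two-argument path would pass a function whose signature has two parameters, which is what triggers the func(keys, pdf) call in end_partition; normalize_with_key is an illustrative name, not part of this commit:

    # Hypothetical two-argument function: keys is a tuple of the group-by
    # values for the partition, here just (id,).
    def normalize_with_key(keys, pdf):
        v = pdf.v
        return pdf.assign(id=keys[0], v=(v - v.mean()) / v.std())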
