Skip to content

Commit bce369c

Browse files
SNOW-1926348: Fix failing tests in daily regress runner (#3131)
1 parent 70e9b2b commit bce369c

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/snowflake/snowpark/modin/plugin/_internal/apply_utils.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -322,10 +322,10 @@ def end_partition(self, df): # type: ignore[no-untyped-def] # pragma: no cover
322322

323323

324324
def convert_groupby_apply_dataframe_result_to_standard_schema(
325-
func_input_df: native_pd.DataFrame,
326325
func_output_df: native_pd.DataFrame,
327326
input_row_positions: native_pd.Series,
328327
include_index_columns: bool,
328+
is_transform: bool,
329329
) -> native_pd.DataFrame: # pragma: no cover: this function runs inside a UDTF, so coverage tools can't detect that we are testing it.
330330
"""
331331
Take the result of applying the user-provided function to a dataframe, and convert it to a dataframe with known schema that we can output from a vUDTF.
@@ -338,6 +338,7 @@ def convert_groupby_apply_dataframe_result_to_standard_schema(
338338
func_input_df came from.
339339
include_index_columns: Whether to include the result's index columns in
340340
the output.
341+
is_transform: Whether the function is a transform or not.
341342
342343
Returns:
343344
A 5-column dataframe that represents the function result per the
@@ -346,7 +347,6 @@ def convert_groupby_apply_dataframe_result_to_standard_schema(
346347
"""
347348
result_rows = []
348349
result_index_names = func_output_df.index.names
349-
is_transform = func_output_df.index.equals(func_input_df.index)
350350
for row_number, (index_label, row) in enumerate(func_output_df.iterrows()):
351351
output_row_number = input_row_positions.iloc[row_number] if is_transform else -1
352352
if include_index_columns:
@@ -460,9 +460,7 @@ def apply_groupby_func_to_df(
460460
args: tuple,
461461
kwargs: dict,
462462
force_list_like_to_series: bool = False,
463-
) -> Tuple[
464-
native_pd.Series, native_pd.DataFrame, native_pd.DataFrame, bool
465-
]: # pragma: no cover
463+
) -> Tuple[native_pd.Series, native_pd.DataFrame, bool, bool]: # pragma: no cover
466464
"""
467465
Restore input dataframe received in udtf to original schema.
468466
Args:
@@ -479,9 +477,9 @@ def apply_groupby_func_to_df(
479477
Returns:
480478
A Tuple of
481479
1. rows positions
482-
2. restored input dataframe.
483-
3. Result of applying the function to input dataframe.
484-
4. Whether final result should include index columns.
480+
2. Result of applying the function to input dataframe.
481+
3. Whether final result should include index columns.
482+
4. Whether the index of the result is the same as the index of the input.
485483
"""
486484
# The first column is row position. Save it for later.
487485
col_offset = 0
@@ -568,17 +566,20 @@ def apply_groupby_func_to_df(
568566
# columns of `func_result_as_frame`. For SeriesGroupBy, we
569567
# do include the result's index in the result.
570568
include_index_columns = series_groupby
569+
is_transform = input_object.index.equals(func_result_as_frame.index)
571570
elif isinstance(func_result, native_pd.DataFrame):
572571
include_index_columns = True
573572
func_result_as_frame = func_result
573+
is_transform = input_object.index.equals(func_result_as_frame.index)
574574
else:
575575
# At this point, we know the function result was not a DataFrame
576576
# or Series
577577
include_index_columns = False
578578
func_result_as_frame = native_pd.DataFrame(
579579
{MODIN_UNNAMED_SERIES_LABEL: [func_result]}
580580
)
581-
return row_positions, input_object, func_result_as_frame, include_index_columns
581+
is_transform = False
582+
return row_positions, func_result_as_frame, include_index_columns, is_transform
582583

583584

584585
def create_udtf_for_groupby_transform(
@@ -657,7 +658,7 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
657658
A dataframe representing the result of applying the user-provided
658659
function to this group.
659660
"""
660-
row_positions, _, func_result, _ = apply_groupby_func_to_df(
661+
row_positions, func_result, _, _ = apply_groupby_func_to_df(
661662
df,
662663
num_by,
663664
index_column_names,
@@ -879,9 +880,9 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
879880
"""
880881
(
881882
row_positions,
882-
input_object,
883883
func_result,
884884
include_index_columns,
885+
is_transform_func,
885886
) = apply_groupby_func_to_df(
886887
df,
887888
num_by,
@@ -894,7 +895,7 @@ def end_partition(self, df: native_pd.DataFrame): # type: ignore[no-untyped-def
894895
force_list_like_to_series,
895896
)
896897
return convert_groupby_apply_dataframe_result_to_standard_schema(
897-
input_object, func_result, row_positions, include_index_columns
898+
func_result, row_positions, include_index_columns, is_transform_func
898899
)
899900

900901
input_types = [

0 commit comments

Comments
 (0)