Skip to content

Commit 41a1b90

Browse files
authored
SNOW-2148589: [Local Testing] Fix window indexing issue (#3462)
1 parent c173b87 commit 41a1b90

File tree

3 files changed

+52
-7
lines changed

3 files changed

+52
-7
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@
88

99
- Added support for validating rows against an XSD schema via the `rowValidationXSDPath` option when reading XML files with a row tag specified by the `rowTag` option.
1010

11+
### Snowpark Local Testing Updates
12+
13+
#### Bug Fixes
14+
15+
- Fixed a bug when processing windowed functions that led to incorrect indexing in results.
16+
1117
## 1.33.0 (YYYY-MM-DD)
1218

1319
### Snowpark Python API Updates

src/snowflake/snowpark/mock/_plan.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2596,6 +2596,7 @@ def _match_pattern(row) -> bool:
25962596

25972597
# Process window frame specification
25982598
# Reference: https://docs.snowflake.com/en/sql-reference/functions-analytic#window-frame-usage-notes
2599+
pd_index = res_index
25992600
if not window_spec.frame_spec or not isinstance(
26002601
window_spec.frame_spec, SpecifiedWindowFrame
26012602
):
@@ -2609,13 +2610,18 @@ def _match_pattern(row) -> bool:
26092610
True,
26102611
False,
26112612
)
2613+
2614+
# Pandas reindexes the data when generating rows in a RollingGroupby
2615+
# The resulting index is not exposed in the window groupings so calculate it here
2616+
if not isinstance(windows, list):
2617+
pd_index = list(windows.count().index)
26122618
else:
26132619
indexer = EntireWindowIndexer()
26142620
rolling = res.rolling(indexer)
26152621
windows = [ordered.loc[w.index] for w in rolling]
26162622
# rolling can unpredictably change the index of the data
26172623
# apply a trivial function to materialize the final index
2618-
res_index = list(rolling.count().index)
2624+
pd_index = list(rolling.count().index)
26192625

26202626
elif isinstance(window_spec.frame_spec.frame_type, RowFrame):
26212627
indexer = RowFrameIndexer(frame_spec=window_spec.frame_spec)
@@ -2663,14 +2669,16 @@ def get_bound(bound):
26632669
# compute window function:
26642670
if isinstance(window_function, (FunctionExpression,)):
26652671
res_cols = []
2666-
for current_row, w in zip(res_index, windows):
2667-
res_cols.append(
2668-
handle_function_expression(
2669-
window_function, w, analyzer, expr_to_alias, current_row
2670-
)
2672+
2673+
for current_row, w in zip(pd_index, windows):
2674+
result = handle_function_expression(
2675+
window_function, w, analyzer, expr_to_alias, current_row
26712676
)
2677+
result.index = [current_row]
2678+
res_cols.append(result)
2679+
26722680
res_col = pd.concat(res_cols) if res_cols else ColumnEmulator([])
2673-
res_col.index = res_index
2681+
res_col.reindex(res_index)
26742682
if res_cols:
26752683
res_col.sf_type = res_cols[0].sf_type
26762684
else:

tests/mock/test_functions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
min,
3131
rank,
3232
row_number,
33+
sum,
3334
to_char,
3435
to_date,
3536
)
@@ -512,6 +513,36 @@ def test_rank(session):
512513
)
513514

514515

516+
def test_window_indexing(session):
517+
df = session.create_dataframe(
518+
[
519+
[1, 1, 1],
520+
[2, 2, 1],
521+
[2, 2, 1],
522+
[2, 1, 1],
523+
],
524+
["A", "B", "VAL"],
525+
)
526+
527+
window_a = Window.partition_by("A")
528+
window_both = Window.partition_by("B", "A")
529+
530+
windowed = df.with_columns(
531+
["_A", "_BA"],
532+
[sum("VAL").over(window_a), sum("VAL").over(window_both)],
533+
)
534+
535+
Utils.check_answer(
536+
windowed,
537+
[
538+
Row(1, 1, 1, 1, 1),
539+
Row(2, 2, 1, 3, 2),
540+
Row(2, 2, 1, 3, 2),
541+
Row(2, 1, 1, 3, 1),
542+
],
543+
)
544+
545+
515546
def test_get(session):
516547
data = [
517548
Row(101, 1, "cat"),

0 commit comments

Comments
 (0)