Skip to content

Commit 48e715d

Browse files
authored
SNOW-2752334: Fix overlap handling when parsing XML file (#4008)
1 parent 2a377bb commit 48e715d

File tree

4 files changed

+60
-20
lines changed

4 files changed

+60
-20
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@
5353

5454
- Catalog API no longer uses types declared in `snowflake.core` and therefore this dependency was removed.
5555

56+
#### Bug Fixes
57+
58+
- Fixed a bug in `XMLReader` where finding the start position of a row tag could return an incorrect file position.
59+
5660
### Snowpark pandas API Updates
5761

5862
#### New Features

src/snowflake/snowpark/_internal/xml_reader.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,6 @@ def find_next_opening_tag_pos(
205205
chunk = file_obj.read(current_chunk_size)
206206
if not chunk:
207207
raise EOFError("Reached end of file before finding opening tag")
208-
# If the chunk is smaller than expected, we are near the end.
209-
if len(chunk) < current_chunk_size:
210-
if chunk.find(tag_start_1) == -1 and chunk.find(tag_start_2) == -1:
211-
raise EOFError("Reached end of file before finding opening tag")
212208

213209
# Combine leftover from previous read with the new chunk.
214210
data = overlap + chunk
@@ -233,9 +229,6 @@ def find_next_opening_tag_pos(
233229
# Update the overlap from the end of the combined data.
234230
overlap = data[-overlap_size:] if len(data) >= overlap_size else data
235231

236-
# Otherwise, rewind by the length of the overlap so that a tag spanning the boundary isn't missed.
237-
file_obj.seek(-len(overlap), 1)
238-
239232
# Check that progress is being made to avoid infinite loops.
240233
if file_obj.tell() <= pos_before:
241234
raise EOFError("No progress made while searching for opening tag")

tests/integ/scala/test_dataframe_aggregate_suite.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ def test_group_by_grouping_sets(session):
591591
.with_column("medical_license", lit(None))
592592
.select("medical_license", "radio_license", "count")
593593
)
594-
.sort(col("count"))
594+
.sort(col("count"), col("radio_license"))
595595
.collect()
596596
)
597597

@@ -601,16 +601,16 @@ def test_group_by_grouping_sets(session):
601601
GroupingSets([col("medical_license")], [col("radio_license")])
602602
)
603603
.agg(count(col("*")).as_("count"))
604-
.sort(col("count"))
604+
.sort(col("count"), col("radio_license"))
605605
)
606606

607607
Utils.check_answer(grouping_sets, result, sort=False)
608608

609609
Utils.check_answer(
610610
grouping_sets,
611611
[
612-
Row(None, "General", 1),
613612
Row(None, "Amateur Extra", 1),
613+
Row(None, "General", 1),
614614
Row("RN", None, 2),
615615
Row(None, "Technician", 2),
616616
Row(None, None, 3),
@@ -624,8 +624,8 @@ def test_group_by_grouping_sets(session):
624624
TestData.nurse(session)
625625
.group_by("medical_license", "radio_license")
626626
.agg(count(col("*")).as_("count"))
627-
.sort(col("count"), col("medical_license"), col("radio_license"))
628-
.select("count", "medical_license", "radio_license"),
627+
.select("count", "medical_license", "radio_license")
628+
.sort(col("count"), col("medical_license"), col("radio_license")),
629629
[
630630
Row(1, "LVN", "General"),
631631
Row(1, "RN", None),
@@ -775,11 +775,13 @@ def test_rel_grouped_dataframe_median(session):
775775
def test_builtin_functions(session):
776776
df = session.create_dataframe([(1, 11), (2, 12), (1, 13)]).to_df(["a", "b"])
777777

778-
assert df.group_by("a").builtin("max")(col("a"), col("b")).collect() == [
778+
assert df.group_by("a").builtin("max")(col("a"), col("b")).sort(
779+
col("a")
780+
).collect() == [
779781
Row(1, 1, 13),
780782
Row(2, 2, 12),
781783
]
782-
assert df.group_by("a").builtin("max")(col("b")).collect() == [
784+
assert df.group_by("a").builtin("max")(col("b")).sort(col("a")).collect() == [
783785
Row(1, 13),
784786
Row(2, 12),
785787
]
@@ -828,16 +830,16 @@ def test_non_empty_arg_functions(session):
828830

829831

830832
def test_null_count(session):
831-
assert TestData.test_data3(session).group_by("a").agg(
832-
count(col("b"))
833+
assert TestData.test_data3(session).group_by("a").agg(count(col("b"))).sort(
834+
col("a")
833835
).collect() == [
834836
Row(1, 0),
835837
Row(2, 1),
836838
]
837839

838840
assert TestData.test_data3(session).group_by("a").agg(
839841
count(col("a") + col("b"))
840-
).collect() == [Row(1, 0), Row(2, 1)]
842+
).sort(col("a")).collect() == [Row(1, 0), Row(2, 1)]
841843

842844
assert TestData.test_data3(session).agg(
843845
[
@@ -1147,9 +1149,12 @@ def test_ints_in_agg_exprs_are_taken_as_groupby_ordinal(session):
11471149
[lit(6), lit(7), sum(col("b"))]
11481150
).collect() == [Row(3, 4, 6, 7, 9)]
11491151

1150-
assert TestData.test_data2(session).group_by([lit(3), lit(4)]).agg(
1151-
[lit(6), col("b"), sum(col("b"))]
1152-
).collect() == [Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)]
1152+
Utils.check_answer(
1153+
TestData.test_data2(session)
1154+
.group_by([lit(3), lit(4)])
1155+
.agg([lit(6), col("b"), sum(col("b"))]),
1156+
[Row(3, 4, 6, 1, 3), Row(3, 4, 6, 2, 6)],
1157+
)
11531158

11541159

11551160
@pytest.mark.xfail(

tests/unit/test_xml_reader.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,44 @@ def test_find_next_opening_tag_pos_normal(chunk_size):
389389
assert pos == expected_pos
390390

391391

392+
@pytest.mark.parametrize("chunk_size", [10, 100, DEFAULT_CHUNK_SIZE])
393+
def test_find_next_opening_tag_pos_full_chunk_before_tag(chunk_size):
394+
# This tests that the overlap logic works correctly when multiple chunks
395+
# must be read before finding the tag.
396+
prefix = b"x" * (chunk_size * 2 + 10) # More than 2 full chunks
397+
record = prefix + b"<row attr='value'> more content here </row>"
398+
file_obj = io.BytesIO(record)
399+
tag_start_1 = b"<row>"
400+
tag_start_2 = b"<row "
401+
end_limit = len(record)
402+
pos = find_next_opening_tag_pos(
403+
file_obj, tag_start_1, tag_start_2, end_limit, chunk_size=chunk_size
404+
)
405+
# Should find the first tag after all the prefix data
406+
expected_pos = len(prefix)
407+
assert pos == expected_pos
408+
# Verify file pointer is at the correct position
409+
assert file_obj.tell() == expected_pos
410+
411+
412+
@pytest.mark.parametrize("chunk_size", [10, 100, DEFAULT_CHUNK_SIZE])
413+
def test_find_next_opening_tag_pos_tag_spans_chunk_boundary(chunk_size):
414+
# Position the tag so it splits exactly across a chunk boundary.
415+
# This is the most challenging case for the overlap logic.
416+
# Place the tag start 2 bytes before the chunk boundary
417+
prefix = b"x" * (chunk_size - 2)
418+
record = prefix + b"<row attr='value'> content </row>"
419+
file_obj = io.BytesIO(record)
420+
tag_start_1 = b"<row>"
421+
tag_start_2 = b"<row "
422+
end_limit = len(record)
423+
pos = find_next_opening_tag_pos(
424+
file_obj, tag_start_1, tag_start_2, end_limit, chunk_size=chunk_size
425+
)
426+
expected_pos = len(prefix)
427+
assert pos == expected_pos
428+
429+
392430
@pytest.mark.parametrize("chunk_size", [3, 10, DEFAULT_CHUNK_SIZE])
393431
def test_find_next_opening_tag_pos_both_variants(chunk_size):
394432
# Test when both "<row>" and "<row " exist.

0 commit comments

Comments
 (0)