Skip to content

Commit 28fbb25

Browse files
authored
Add missing stream synchronizations to various tests (#21122)
I noticed these issues when investigating #21094. I suspect that this was the underlying issue behind #19900. Authors: - Vyas Ramasubramani (https://github.com/vyasr) - Matthew Murray (https://github.com/Matt711) Approvers: - Matthew Murray (https://github.com/Matt711) URL: #21122
1 parent 1b09b30 commit 28fbb25

File tree

7 files changed

+68
-8
lines changed

7 files changed

+68
-8
lines changed

python/pylibcudf/tests/common/utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
from __future__ import annotations
44

@@ -16,6 +16,20 @@
1616
from pylibcudf.io.types import CompressionType
1717

1818

19+
def synchronize_stream(stream=None):
20+
"""Synchronize a stream, handling both explicit streams and None (default stream).
21+
22+
Parameters
23+
----------
24+
stream : Stream or None
25+
The stream to synchronize. If None, synchronizes the default stream.
26+
"""
27+
if stream is None:
28+
plc.utils.DEFAULT_STREAM.synchronize()
29+
else:
30+
stream.synchronize()
31+
32+
1933
def metadata_from_arrow_type(
2034
pa_type: pa.Array,
2135
name: str = "",

python/pylibcudf/tests/io/test_avro.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
import io
@@ -7,7 +7,7 @@
77
import fastavro
88
import pyarrow as pa
99
import pytest
10-
from utils import assert_table_and_meta_eq
10+
from utils import assert_table_and_meta_eq, synchronize_stream
1111

1212
from rmm.pylibrmm.device_buffer import DeviceBuffer
1313
from rmm.pylibrmm.stream import Stream
@@ -156,6 +156,8 @@ def test_read_avro_from_device_buffers(avro_dtypes, avro_dtype_data, stream):
156156
buf = buffer.getbuffer()
157157
device_buf = DeviceBuffer.to_device(buf, plc.utils._get_stream(stream))
158158

159+
synchronize_stream(stream)
160+
159161
options = plc.io.avro.AvroReaderOptions.builder(
160162
plc.io.types.SourceInfo([device_buf])
161163
).build()

python/pylibcudf/tests/io/test_csv.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
import io
44
import os
@@ -11,6 +11,7 @@
1111
assert_table_and_meta_eq,
1212
make_source,
1313
sink_to_str,
14+
synchronize_stream,
1415
write_source_str,
1516
)
1617

@@ -314,6 +315,8 @@ def test_read_csv_from_device_buffers(csv_table_data, stream):
314315
csv_string.encode("utf-8"), plc.utils._get_stream(stream)
315316
)
316317

318+
synchronize_stream(stream)
319+
317320
options = plc.io.csv.CsvReaderOptions.builder(
318321
plc.io.SourceInfo([buf])
319322
).build()
@@ -379,6 +382,8 @@ def test_write_csv(
379382
stream,
380383
)
381384

385+
synchronize_stream(stream)
386+
382387
# Convert everything to string to make comparisons easier
383388
str_result = sink_to_str(sink)
384389

@@ -423,6 +428,8 @@ def test_write_csv_na_rep(na_rep):
423428
)
424429
)
425430

431+
synchronize_stream()
432+
426433
# Convert everything to string to make comparisons easier
427434
str_result = sink_to_str(sink)
428435

python/pylibcudf/tests/io/test_experimental_hybrid_scan.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
import io
44

55
import pyarrow as pa
66
import pyarrow.parquet as pq
77
import pytest
8+
from utils import synchronize_stream
89

910
from rmm import DeviceBuffer
1011
from rmm.pylibrmm.stream import Stream
@@ -327,6 +328,8 @@ def test_hybrid_scan_materialize_columns(
327328
for r in filter_ranges
328329
]
329330

331+
synchronize_stream(stream)
332+
330333
# Materialize filter columns (mr is optional, defaults to None)
331334
filter_result = simple_hybrid_scan_reader.materialize_filter_columns(
332335
filtered_row_groups,
@@ -337,6 +340,8 @@ def test_hybrid_scan_materialize_columns(
337340
stream,
338341
)
339342

343+
synchronize_stream(stream)
344+
340345
# Filter column should have 1 column, with rows passing the filter
341346
expected_result_rows = num_rows - filter_threshold
342347
assert filter_result.tbl.num_columns() == 1
@@ -356,6 +361,8 @@ def test_hybrid_scan_materialize_columns(
356361
for r in payload_ranges
357362
]
358363

364+
synchronize_stream(stream)
365+
359366
# Materialize payload columns (mr is optional, defaults to None)
360367
payload_result = simple_hybrid_scan_reader.materialize_payload_columns(
361368
filtered_row_groups,
@@ -366,6 +373,8 @@ def test_hybrid_scan_materialize_columns(
366373
stream,
367374
)
368375

376+
synchronize_stream(stream)
377+
369378
assert payload_result.tbl.num_columns() == 2
370379
assert payload_result.tbl.num_rows() == expected_result_rows
371380

@@ -379,6 +388,8 @@ def test_hybrid_scan_materialize_columns(
379388
comparison_options.set_filter(filter_expression)
380389
expected_result = plc.io.parquet.read_parquet(comparison_options, stream)
381390

391+
synchronize_stream(stream)
392+
382393
# Combine hybrid scan results
383394
hybrid_columns = filter_result.tbl.columns() + payload_result.tbl.columns()
384395
hybrid_table = plc.Table(hybrid_columns)
@@ -437,6 +448,8 @@ def test_hybrid_scan_has_next_table_chunk(
437448
for r in filter_ranges
438449
]
439450

451+
synchronize_stream()
452+
440453
# Setup chunking first
441454
simple_hybrid_scan_reader.setup_chunking_for_filter_columns(
442455
512, # chunk_read_limit
@@ -503,6 +516,8 @@ def test_hybrid_scan_chunked_reading(
503516
for r in filter_ranges
504517
]
505518

519+
synchronize_stream(stream)
520+
506521
# Setup chunking for filter columns with small chunk size
507522
chunk_read_limit = 512 # Small limit to force multiple chunks
508523
pass_read_limit = 0 # No limit

python/pylibcudf/tests/io/test_json.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
import io
44

@@ -9,6 +9,7 @@
99
assert_table_and_meta_eq,
1010
make_source,
1111
sink_to_str,
12+
synchronize_stream,
1213
write_source_str,
1314
)
1415

@@ -44,6 +45,8 @@ def test_write_json_basic(
4445

4546
plc.io.json.write_json(options, stream)
4647

48+
synchronize_stream(stream)
49+
4750
exp = pa_table.to_pandas()
4851

4952
# Convert everything to string to make
@@ -82,6 +85,8 @@ def test_write_json_nulls(na_rep, include_nulls):
8285

8386
plc.io.json.write_json(options)
8487

88+
synchronize_stream()
89+
8590
exp = pa_tbl.to_pandas()
8691

8792
# Convert everything to string to make
@@ -133,6 +138,8 @@ def test_write_json_bool_opts(true_value, false_value):
133138

134139
plc.io.json.write_json(options)
135140

141+
synchronize_stream()
142+
136143
exp = pa_tbl.to_pandas()
137144

138145
# Convert everything to string to make
@@ -428,6 +435,8 @@ def test_read_json_from_device_buffers(table_data, num_buffers, stream):
428435
json_str.encode("utf-8"), plc.utils._get_stream(stream)
429436
)
430437

438+
synchronize_stream(stream)
439+
431440
options = (
432441
plc.io.json.JsonReaderOptions.builder(
433442
plc.io.SourceInfo([buf] * num_buffers)
@@ -471,6 +480,8 @@ def test_utf8_escaped_json_writer(tmp_path):
471480
)
472481
plc.io.json.write_json(options)
473482

483+
synchronize_stream()
484+
474485
output_string = path.read_text(encoding="utf-8").strip()
475486

476487
assert output_string == '[{"0":"C𝞵𝓓𝒻"}]'

python/pylibcudf/tests/io/test_orc.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33

44
import pyarrow as pa
@@ -8,6 +8,7 @@
88
assert_table_and_meta_eq,
99
get_bytes_from_source,
1010
make_source,
11+
synchronize_stream,
1112
)
1213

1314
from rmm.pylibrmm.device_buffer import DeviceBuffer
@@ -107,6 +108,8 @@ def test_read_orc_from_device_buffers(
107108
get_bytes_from_source(source), plc.utils._get_stream(stream)
108109
)
109110

111+
synchronize_stream(stream)
112+
110113
options = plc.io.orc.OrcReaderOptions.builder(
111114
plc.io.types.SourceInfo([buf] * num_buffers)
112115
).build()
@@ -179,6 +182,8 @@ def test_roundtrip_pa_table(
179182

180183
plc.io.orc.write_orc(options, stream)
181184

185+
synchronize_stream(stream)
186+
182187
read_table = pa.orc.read_table(str(tmpfile_name))
183188

184189
res = plc.io.types.TableWithMetadata(

python/pylibcudf/tests/io/test_parquet.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
1+
# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION.
22
# SPDX-License-Identifier: Apache-2.0
33
import io
44

@@ -10,6 +10,7 @@
1010
assert_table_and_meta_eq,
1111
get_bytes_from_source,
1212
make_source,
13+
synchronize_stream,
1314
)
1415

1516
from rmm.pylibrmm.device_buffer import DeviceBuffer
@@ -192,6 +193,8 @@ def test_read_parquet_from_device_buffers(
192193
get_bytes_from_source(source), plc.utils._get_stream(stream)
193194
)
194195

196+
synchronize_stream(stream)
197+
195198
options = plc.io.parquet.ParquetReaderOptions.builder(
196199
plc.io.SourceInfo([buf] * num_buffers)
197200
).build()
@@ -289,6 +292,9 @@ def test_write_parquet(
289292
options.set_max_dictionary_size(max_dictionary_size)
290293

291294
result = plc.io.parquet.write_parquet(options, stream)
295+
296+
synchronize_stream(stream)
297+
292298
assert isinstance(result, memoryview)
293299

294300

0 commit comments

Comments
 (0)