Skip to content

Commit 701bd6c

Browse files
authored
pickling support (#5011)
Implements `__reduce__` and `__reduce_ex__` methods to enable pickling of Vortex arrays in Python. Arrays are serialized using the Vortex IPC format. For pickle protocol 5+ (Python 3.8+, PEP 574), uses PickleBuffer to keep array buffers separate from the main pickle stream rather than copying them inline. This enables us to use shared memory in the future to potentially zero-copy large arrays even across process boundaries. Protocol 4 and below serialise buffers inline as bytes. Both protocols share the same deserialization path via `decode_ipc_array_buffers`, which reconstructs arrays from IPC-encoded buffer lists (or memoryviews). --------- Signed-off-by: Onur Satici <[email protected]>
1 parent 99101b1 commit 701bd6c

File tree

11 files changed

+479
-4
lines changed

11 files changed

+479
-4
lines changed

Cargo.lock

Lines changed: 12 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ prost = "0.14"
163163
prost-build = "0.14"
164164
prost-types = "0.14"
165165
pyo3 = { version = "0.26.0" }
166+
pyo3-bytes = "0.4"
166167
pyo3-log = "0.13.0"
167168
rand = "0.9.0"
168169
rand_distr = "0.5"

vortex-python/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@ crate-type = ["rlib", "cdylib"]
2525
arrow-array = { workspace = true }
2626
arrow-data = { workspace = true }
2727
arrow-schema = { workspace = true }
28+
bytes = { workspace = true }
2829
itertools = { workspace = true }
2930
log = { workspace = true }
3031
object_store = { workspace = true, features = ["aws", "gcp", "azure", "http"] }
3132
parking_lot = { workspace = true }
3233
pyo3 = { workspace = true, features = ["abi3", "abi3-py311"] }
34+
pyo3-bytes = { workspace = true }
3335
pyo3-log = { workspace = true }
3436
tokio = { workspace = true, features = ["fs", "rt-multi-thread"] }
3537
url = { workspace = true }

vortex-python/benchmark/conftest.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import hashlib
55
import math
66
import os
7+
from typing import cast
78

89
import pyarrow as pa
910
import pytest
@@ -34,3 +35,9 @@ def vxf(tmpdir_factory: pytest.TempPathFactory, request: pytest.FixtureRequest)
3435
a = vx.array(pa.table(columns)) # pyright: ignore[reportCallIssue, reportUnknownArgumentType, reportArgumentType]
3536
vx.io.write(a, str(fname))
3637
return vx.open(str(fname))
38+
39+
40+
@pytest.fixture(scope="session", params=[10_000, 2_000_000], ids=["small", "large"])
41+
def array_fixture(request: pytest.FixtureRequest) -> vx.Array:
42+
size = cast(int, request.param)
43+
return vx.array(pa.table({"x": list(range(size))}))
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
# SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
import pickle
5+
6+
import pytest
7+
from pytest_benchmark.fixture import BenchmarkFixture # pyright: ignore[reportMissingTypeStubs]
8+
9+
import vortex as vx
10+
11+
12+
@pytest.mark.parametrize("protocol", [4, 5], ids=lambda p: f"p{p}") # pyright: ignore[reportAny]
13+
@pytest.mark.parametrize("operation", ["dumps", "loads", "roundtrip"])
14+
@pytest.mark.benchmark(disable_gc=True)
15+
def test_pickle(
16+
benchmark: BenchmarkFixture,
17+
array_fixture: vx.Array,
18+
protocol: int,
19+
operation: str,
20+
):
21+
benchmark.group = f"pickle_p{protocol}"
22+
23+
if operation == "dumps":
24+
benchmark(lambda: pickle.dumps(array_fixture, protocol=protocol))
25+
elif operation == "loads":
26+
pickled_data = pickle.dumps(array_fixture, protocol=protocol)
27+
benchmark(lambda: pickle.loads(pickled_data)) # pyright: ignore[reportAny]
28+
elif operation == "roundtrip":
29+
benchmark(lambda: pickle.loads(pickle.dumps(array_fixture, protocol=protocol))) # pyright: ignore[reportAny]

vortex-python/python/vortex/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,12 @@
7070
scalar,
7171
)
7272
from ._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource]
73-
from .arrays import Array, PyArray, array
73+
from .arrays import (
74+
Array,
75+
PyArray,
76+
_unpickle_array, # pyright: ignore[reportPrivateUsage]
77+
array,
78+
)
7479
from .file import VortexFile, open
7580
from .scan import RepeatedScan
7681

@@ -155,6 +160,8 @@
155160
# Serde
156161
"ArrayContext",
157162
"ArrayParts",
163+
# Pickle
164+
"_unpickle_array",
158165
# File
159166
"VortexFile",
160167
"open",

vortex-python/python/vortex/_lib/serde.pyi

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
# SPDX-License-Identifier: Apache-2.0
22
# SPDX-FileCopyrightText: Copyright the Vortex contributors
33

4+
from collections.abc import Sequence
45
from typing import final
56

67
import pyarrow as pa
78

9+
from .arrays import Array
810
from .dtype import DType
911

1012
@final
@@ -26,3 +28,8 @@ class ArrayParts:
2628
@final
2729
class ArrayContext:
2830
def __len__(self) -> int: ...
31+
32+
def decode_ipc_array(array_bytes: bytes, dtype_bytes: bytes) -> Array: ...
33+
def decode_ipc_array_buffers(
34+
array_buffers: Sequence[bytes | memoryview], dtype_buffers: Sequence[bytes | memoryview]
35+
) -> Array: ...

vortex-python/python/vortex/arrays.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,19 @@
33
from __future__ import annotations
44

55
import abc
6-
from collections.abc import Callable
6+
from collections.abc import Callable, Sequence
77
from typing import TYPE_CHECKING, Any
88

99
import pyarrow
1010
from typing_extensions import override
1111

1212
import vortex._lib.arrays as _arrays # pyright: ignore[reportMissingModuleSource]
1313
from vortex._lib.dtype import DType # pyright: ignore[reportMissingModuleSource]
14-
from vortex._lib.serde import ArrayContext, ArrayParts # pyright: ignore[reportMissingModuleSource]
14+
from vortex._lib.serde import ( # pyright: ignore[reportMissingModuleSource]
15+
ArrayContext,
16+
ArrayParts,
17+
decode_ipc_array_buffers,
18+
)
1519

1620
try:
1721
import pandas
@@ -466,3 +470,15 @@ def decode(cls, parts: ArrayParts, ctx: ArrayContext, dtype: DType, len: int) ->
466470
current array. Implementations of this function should validate this information, and then construct
467471
a new array.
468472
"""
473+
474+
475+
def _unpickle_array(array_buffers: Sequence[bytes | memoryview], dtype_buffers: Sequence[bytes | memoryview]) -> Array: # pyright: ignore[reportUnusedFunction]
476+
"""Unpickle a Vortex array from IPC-encoded buffer lists.
477+
478+
This is an internal function used by the pickle module for both protocol 4 and 5.
479+
480+
For protocol 4, receives list[bytes] from __reduce__.
481+
For protocol 5, receives list[PickleBuffer/memoryview] from __reduce_ex__.
482+
Both use decode_ipc_array_buffers which concatenates the buffers during deserialization.
483+
"""
484+
return decode_ipc_array_buffers(array_buffers, dtype_buffers)

vortex-python/src/arrays/mod.rs

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,17 @@ pub(crate) mod py;
1010
mod range_to_sequence;
1111

1212
use arrow_array::{Array as ArrowArray, ArrayRef as ArrowArrayRef};
13+
use pyo3::IntoPyObjectExt;
1314
use pyo3::exceptions::{PyTypeError, PyValueError};
1415
use pyo3::prelude::*;
1516
use pyo3::types::{PyDict, PyList, PyRange, PyRangeMethods};
17+
use pyo3_bytes::PyBytes;
1618
use vortex::arrays::ChunkedVTable;
1719
use vortex::arrow::IntoArrowArray;
1820
use vortex::compute::{Operator, compare, take};
1921
use vortex::dtype::{DType, Nullability, PType, match_each_integer_ptype};
2022
use vortex::error::VortexError;
23+
use vortex::ipc::messages::{EncoderMessage, MessageEncoder};
2124
use vortex::{Array, ArrayRef, ToCanonical};
2225

2326
use crate::arrays::native::PyNativeArray;
@@ -653,4 +656,76 @@ impl PyArray {
653656
.map(|buffer| buffer.to_vec())
654657
.collect())
655658
}
659+
660+
/// Support for Python's pickle protocol.
661+
///
662+
/// This method serializes the array using Vortex IPC format and returns
663+
/// the data needed for pickle to reconstruct the array.
664+
fn __reduce__<'py>(
665+
slf: &'py Bound<'py, Self>,
666+
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
667+
let py = slf.py();
668+
let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner();
669+
670+
let mut encoder = MessageEncoder::default();
671+
let buffers = encoder.encode(EncoderMessage::Array(&*array));
672+
673+
// Return buffers as a list instead of concatenating
674+
let array_buffers: Vec<Vec<u8>> = buffers.iter().map(|b| b.to_vec()).collect();
675+
676+
let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype()));
677+
let dtype_buffers: Vec<Vec<u8>> = dtype_buffers.iter().map(|b| b.to_vec()).collect();
678+
679+
let vortex_module = PyModule::import(py, "vortex")?;
680+
let unpickle_fn = vortex_module.getattr("_unpickle_array")?;
681+
682+
let args = (array_buffers, dtype_buffers).into_pyobject(py)?;
683+
Ok((unpickle_fn, args.into_any()))
684+
}
685+
686+
/// Support for Python's pickle protocol for protocol >= 5
687+
///
688+
/// uses PickleBuffer for out-of-band buffer transfer,
689+
/// which potentially avoids copying large data buffers.
690+
fn __reduce_ex__<'py>(
691+
slf: &'py Bound<'py, Self>,
692+
protocol: i32,
693+
) -> PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)> {
694+
let py = slf.py();
695+
696+
if protocol < 5 {
697+
return Self::__reduce__(slf);
698+
}
699+
700+
let array = PyArrayRef::extract_bound(slf.as_any())?.into_inner();
701+
702+
let mut encoder = MessageEncoder::default();
703+
let array_buffers = encoder.encode(EncoderMessage::Array(&*array));
704+
let dtype_buffers = encoder.encode(EncoderMessage::DType(array.dtype()));
705+
706+
let pickle_module = PyModule::import(py, "pickle")?;
707+
let pickle_buffer_class = pickle_module.getattr("PickleBuffer")?;
708+
709+
let mut pickle_buffers = Vec::new();
710+
for buf in array_buffers.into_iter() {
711+
// PyBytes wraps bytes::Bytes and implements the buffer protocol
712+
// This allows PickleBuffer to reference the data without copying
713+
let py_bytes = PyBytes::new(buf).into_py_any(py)?;
714+
let pickle_buffer = pickle_buffer_class.call1((py_bytes,))?;
715+
pickle_buffers.push(pickle_buffer);
716+
}
717+
718+
let mut dtype_pickle_buffers = Vec::new();
719+
for buf in dtype_buffers.into_iter() {
720+
let py_bytes = PyBytes::new(buf).into_py_any(py)?;
721+
let pickle_buffer = pickle_buffer_class.call1((py_bytes,))?;
722+
dtype_pickle_buffers.push(pickle_buffer);
723+
}
724+
725+
let vortex_module = PyModule::import(py, "vortex")?;
726+
let unpickle_fn = vortex_module.getattr("_unpickle_array")?;
727+
728+
let args = (pickle_buffers, dtype_pickle_buffers).into_pyobject(py)?;
729+
Ok((unpickle_fn, args.into_any()))
730+
}
656731
}

0 commit comments

Comments
 (0)