Skip to content

Commit 7eade9d

Browse files
authored
PyVortex to use new scan executor (#2908)
1 parent 15b8846 commit 7eade9d

File tree

6 files changed

+235
-221
lines changed

6 files changed

+235
-221
lines changed

pyvortex/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ dependencies = ["pyarrow>=17.0.0", "substrait>=0.23.0"]
88
requires-python = ">= 3.10"
99

1010
[project.optional-dependencies]
11-
polars = ["polars>=1.24.0"]
11+
polars = ["polars>=1.27.0"]
1212
pandas = ["pandas>=2.2.0"]
1313
numpy = ["numpy>=1.26.0"]
1414
duckdb = ["duckdb>=1.1.2"]

pyvortex/python/vortex/file.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,7 @@ def _io_source(
3030
if predicate is not None:
3131
predicate = polars_to_vortex(predicate)
3232

33-
reader = self.to_arrow(
34-
projection=with_columns,
35-
expr=predicate,
36-
batch_size=batch_size or 8192,
37-
)
33+
reader = self.to_arrow(projection=with_columns, expr=predicate)
3834

3935
for batch in reader:
4036
batch = pl.DataFrame._from_arrow(batch, rechunk=False)
@@ -46,7 +42,7 @@ def _io_source(
4642
data=pa.RecordBatch.from_arrays(
4743
[pa.array([], type=field.type) for field in reader.schema],
4844
schema=reader.schema,
49-
)
45+
),
5046
)
5147

5248
return register_io_source(_io_source, schema=schema)

pyvortex/python/vortex/polars_.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,13 @@ def polars_to_vortex(expr: pl.Expr) -> ve.Expr:
2525
"LogicalOr": operator.or_,
2626
}
2727

28+
29+
def _unsupported(v, name: str):
30+
raise ValueError(f"Unsupported Polars expression {name}: {v}")
31+
32+
2833
_LITERAL_TYPES = {
34+
"Boolean": lambda v: vx.bool_(nullable=v is None),
2935
"Int": lambda v: vx.int_(64, nullable=v is None),
3036
"Int8": lambda v: vx.int_(8, nullable=v is None),
3137
"Int16": lambda v: vx.int_(16, nullable=v is None),
@@ -37,7 +43,6 @@ def polars_to_vortex(expr: pl.Expr) -> ve.Expr:
3743
"UInt64": lambda v: vx.uint(64, nullable=v is None),
3844
"Float32": lambda v: vx.float_(32, nullable=v is None),
3945
"Float64": lambda v: vx.float_(64, nullable=v is None),
40-
"Boolean": lambda v: vx.bool_(nullable=v is None),
4146
"Null": lambda v: vx.null(),
4247
"String": lambda v: vx.utf8(nullable=v is None),
4348
"Binary": lambda v: vx.binary(nullable=v is None),
@@ -59,11 +64,28 @@ def _polars_to_vortex(expr: dict) -> ve.Expr:
5964
if "Column" in expr:
6065
return ve.column(expr["Column"])
6166

67+
# See https://github.com/pola-rs/polars/pull/21849)
68+
if "Scalar" in expr:
69+
dtype = expr["Scalar"]["dtype"] # DType
70+
value = expr["Scalar"]["value"] # AnyValue
71+
72+
if "Null" in value:
73+
value = None
74+
elif "StringOwned" in value:
75+
value = value["StringOwned"]
76+
else:
77+
raise ValueError(f"Unsupported Polars scalar value type {value}")
78+
79+
return ve.literal(_LITERAL_TYPES[dtype](value), value)
80+
6281
if "Literal" in expr:
6382
expr = expr["Literal"]
6483

6584
literal_type = next(iter(expr.keys()), None)
6685

86+
if literal_type == "Scalar":
87+
return _polars_to_vortex(expr)
88+
6789
# Special-case Series
6890
if literal_type == "Series":
6991
expr = pl.Expr.from_json(json.dumps({"Literal": expr}))
@@ -91,6 +113,12 @@ def _polars_to_vortex(expr: dict) -> ve.Expr:
91113
dtype = vx.ext("vortex.timestamp", vx.int_(64, nullable=value is None), metadata=metadata)
92114
return ve.literal(dtype, value)
93115

116+
# Unwrap 'Dyn' scalars, whose type hasn't been established yet.
117+
# (post https://github.com/pola-rs/polars/pull/21849)
118+
if literal_type == "Dyn":
119+
expr = expr["Dyn"]
120+
literal_type = next(iter(expr.keys()), None)
121+
94122
if literal_type not in _LITERAL_TYPES:
95123
raise NotImplementedError(f"Unsupported Polars literal type: {literal_type}")
96124
value = expr[literal_type]
@@ -111,6 +139,11 @@ def _polars_to_vortex(expr: dict) -> ve.Expr:
111139

112140
# Vortex doesn't support is-in, so we need to construct a series of ORs?
113141

142+
if "StringExpr" in fn:
143+
fn = fn["StringExpr"]
144+
if "Contains" in fn:
145+
raise ValueError("Unsupported Polars StringExpr.Contains")
146+
114147
raise NotImplementedError(f"Unsupported Polars function: {fn}")
115148

116149
raise NotImplementedError(f"Unsupported Polars expression: {expr}")

pyvortex/src/dataset.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ pub async fn read_array_from_reader(
4949
scan = scan.with_row_indices(indices);
5050
}
5151

52-
let stream = scan.into_array_stream()?;
52+
let stream = scan.spawn_tokio(TOKIO_RUNTIME.handle().clone())?;
5353
let dtype = stream.dtype().clone();
5454

5555
let all_arrays = stream.try_collect::<Vec<_>>().await?;

pyvortex/src/file.rs

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,20 @@ use std::sync::Arc;
22

33
use arrow::array::RecordBatchReader;
44
use arrow::pyarrow::IntoPyArrow;
5-
use futures::{SinkExt, StreamExt};
65
use pyo3::exceptions::PyTypeError;
76
use pyo3::prelude::*;
87
use pyo3::types::PyList;
98
use vortex::ToCanonical;
109
use vortex::compute::try_cast;
1110
use vortex::dtype::Nullability::NonNullable;
1211
use vortex::dtype::{DType, PType};
13-
use vortex::error::{VortexExpect, vortex_err};
12+
use vortex::error::VortexError;
1413
use vortex::expr::{ExprRef, ident, select};
1514
use vortex::file::scan::SplitBy;
1615
use vortex::file::segments::MokaSegmentCache;
1716
use vortex::file::{VortexFile, VortexOpenOptions};
1817
use vortex::io::TokioFile;
19-
use vortex::stream::{ArrayStream, ArrayStreamAdapter, ArrayStreamExt};
18+
use vortex::stream::ArrayStreamExt;
2019

2120
use crate::arrays::PyArrayRef;
2221
use crate::dataset::PyVortexDataset;
@@ -184,37 +183,23 @@ impl PyVortexFile {
184183
expr: Option<PyExpr>,
185184
batch_size: Option<usize>,
186185
) -> PyResult<PyObject> {
187-
let mut builder = slf
188-
.get()
189-
.vxf
190-
.scan()?
191-
.with_canonicalize(true)
192-
.with_some_filter(expr.map(|e| e.into_inner()))
193-
.with_projection(projection.map(|p| p.0).unwrap_or_else(ident));
194-
195-
if let Some(batch_size) = batch_size {
196-
builder = builder.with_split_by(SplitBy::RowCount(batch_size));
197-
}
198-
199-
let stream = ArrayStreamExt::boxed(builder.spawn_tokio(TOKIO_RUNTIME.handle().clone())?);
200-
let dtype = stream.dtype().clone();
201-
202-
// The I/O of the array stream won't make progress unless it's polled. So we need to spawn it.
203-
let (mut send, recv) = futures::channel::mpsc::unbounded();
204-
205-
TOKIO_RUNTIME
206-
.block_on(TOKIO_RUNTIME.spawn(async move {
207-
let mut stream = stream;
208-
while let Some(batch) = stream.next().await {
209-
send.send(batch)
210-
.await
211-
.map_err(|e| vortex_err!("Send failed {}", e))
212-
.vortex_expect("send failed");
213-
}
214-
}))
215-
.vortex_expect("failed to spawn stream");
216-
217-
let stream = ArrayStreamAdapter::new(dtype, recv);
186+
let vxf = slf.get().vxf.clone();
187+
188+
let stream = slf.py().allow_threads(|| {
189+
let mut builder = vxf
190+
.scan()?
191+
.with_canonicalize(true)
192+
.with_some_filter(expr.map(|e| e.into_inner()))
193+
.with_projection(projection.map(|p| p.0).unwrap_or_else(ident));
194+
195+
if let Some(batch_size) = batch_size {
196+
builder = builder.with_split_by(SplitBy::RowCount(batch_size));
197+
}
198+
199+
Ok::<_, VortexError>(ArrayStreamExt::boxed(
200+
builder.spawn_tokio(TOKIO_RUNTIME.handle().clone())?,
201+
))
202+
})?;
218203

219204
let iter = ArrayStreamToIterator::new(stream);
220205
let rbr: Box<dyn RecordBatchReader + Send> =

0 commit comments

Comments
 (0)