
Commit c50a5ee

IntoArrow compute function (#2113)
Create a `ToArrow` compute function that takes a target Arrow `DataType`. This allows us in the future to support e.g. `DictArray::to_arrow(DataType::Dictionary(...))`, which currently isn't possible with the `into_canonical` hack we use for VarBin.
1 parent ca748db · commit c50a5ee
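
As orientation for the diff below, here is a minimal caller-side sketch of the two conversion entry points this commit introduces. The signatures are inferred from the call sites in the changed files rather than from crate documentation, so treat this as an assumption-laden sketch, not the canonical API.

```rust
use arrow_schema::DataType;
use vortex::arrow::{infer_data_type, IntoArrowArray};
use vortex::error::VortexResult;
use vortex::{ArrayDType, ArrayData};

// Two ways to get an Arrow array out of a Vortex array after this change:
// let the encoding pick its preferred Arrow type, or request an explicit one.
fn convert(array: ArrayData) -> VortexResult<()> {
    // Encoding-preferred conversion (the old `into_arrow()` call sites now use this).
    let _preferred = array.clone().into_arrow_preferred()?;

    // Explicit conversion: derive a target Arrow DataType from the Vortex DType,
    // then ask the conversion to produce exactly that type.
    let target: DataType = infer_data_type(array.dtype())?;
    let _explicit = array.into_arrow(&target)?;

    Ok(())
}
```

In short, `into_arrow_preferred()` replaces the old `into_arrow()` call sites, while `into_arrow(&DataType)` lets a caller pin the target Arrow type.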

File tree: 46 files changed, +727 -491 lines

(Large commit: some changed files are hidden by default and not shown below.)

bench-vortex/benches/compress.rs

Lines changed: 2 additions & 1 deletion
@@ -29,6 +29,7 @@ use regex::Regex;
 use simplelog::*;
 use tokio::runtime::{Handle, Runtime};
 use vortex::array::{ChunkedArray, StructArray};
+use vortex::arrow::IntoArrowArray;
 use vortex::dtype::FieldName;
 use vortex::error::VortexResult;
 use vortex::file::{ExecutionMode, Scan, VortexOpenOptions, VortexWriteOptions};
@@ -131,7 +132,7 @@ fn vortex_decompress_read(runtime: &Runtime, buf: Bytes) -> VortexResult<Vec<Arr
         .try_collect::<Vec<_>>()
         .await?
         .into_iter()
-        .map(|a| a.into_arrow())
+        .map(|a| a.into_arrow_preferred())
         .collect::<VortexResult<Vec<_>>>()
     })
 }

bench-vortex/src/lib.rs

Lines changed: 4 additions & 4 deletions
@@ -344,10 +344,10 @@ mod test {
     use arrow_array::{ArrayRef as ArrowArrayRef, StructArray as ArrowStructArray};
     use log::LevelFilter;
     use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;
-    use vortex::arrow::FromArrowArray;
+    use vortex::arrow::{FromArrowArray, IntoArrowArray};
     use vortex::compress::CompressionStrategy;
     use vortex::sampling_compressor::SamplingCompressor;
-    use vortex::{ArrayData, IntoCanonical};
+    use vortex::ArrayData;

     use crate::taxi_data::taxi_data_parquet;
     use crate::{compress_taxi_data, setup_logger};
@@ -370,7 +370,7 @@ mod test {
             let struct_arrow: ArrowStructArray = record_batch.into();
             let arrow_array: ArrowArrayRef = Arc::new(struct_arrow);
             let vortex_array = ArrayData::from_arrow(arrow_array.clone(), false);
-            let vortex_as_arrow = vortex_array.into_arrow().unwrap();
+            let vortex_as_arrow = vortex_array.into_arrow_preferred().unwrap();
             assert_eq!(vortex_as_arrow.deref(), arrow_array.deref());
         }
     }
@@ -391,7 +391,7 @@ mod test {
             let vortex_array = ArrayData::from_arrow(arrow_array.clone(), false);

             let compressed = compressor.compress(&vortex_array).unwrap();
-            let compressed_as_arrow = compressed.into_arrow().unwrap();
+            let compressed_as_arrow = compressed.into_arrow_preferred().unwrap();
             assert_eq!(compressed_as_arrow.deref(), arrow_array.deref());
         }
     }

fuzz/fuzz_targets/file_io.rs

Lines changed: 4 additions & 3 deletions
@@ -7,8 +7,9 @@ use bytes::Bytes;
 use futures_util::TryStreamExt;
 use libfuzzer_sys::{fuzz_target, Corpus};
 use vortex_array::array::ChunkedArray;
+use vortex_array::arrow::IntoArrowArray;
 use vortex_array::compute::{compare, Operator};
-use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant, IntoCanonical};
+use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};
 use vortex_dtype::DType;
 use vortex_file::{Scan, VortexOpenOptions, VortexWriteOptions};
 use vortex_sampling_compressor::ALL_ENCODINGS_CONTEXT;
@@ -91,8 +92,8 @@ fn compare_struct(expected: ArrayData, actual: ArrayData) {
         return;
     }

-    let arrow_lhs = expected.clone().into_arrow().unwrap();
-    let arrow_rhs = actual.clone().into_arrow().unwrap();
+    let arrow_lhs = expected.clone().into_arrow_preferred().unwrap();
+    let arrow_rhs = actual.clone().into_arrow_preferred().unwrap();

     let cmp_fn = make_comparator(&arrow_lhs, &arrow_rhs, SortOptions::default()).unwrap();

pyvortex/src/array.rs

Lines changed: 8 additions & 3 deletions
@@ -4,9 +4,10 @@ use pyo3::exceptions::PyValueError;
 use pyo3::prelude::*;
 use pyo3::types::{IntoPyDict, PyInt, PyList};
 use vortex::array::ChunkedArray;
+use vortex::arrow::{infer_data_type, IntoArrowArray};
 use vortex::compute::{compare, fill_forward, scalar_at, slice, take, Operator};
 use vortex::mask::Mask;
-use vortex::{ArrayDType, ArrayData, IntoCanonical};
+use vortex::{ArrayDType, ArrayData};

 use crate::dtype::PyDType;
 use crate::python_repr::PythonRepr;
@@ -119,9 +120,13 @@ impl PyArray {
         let vortex = &self_.inner;

         if let Ok(chunked_array) = ChunkedArray::try_from(vortex.clone()) {
+            // We figure out a single Arrow Data Type to convert all chunks into, otherwise
+            // the preferred type of each chunk may be different.
+            let arrow_dtype = infer_data_type(chunked_array.dtype())?;
+
             let chunks: Vec<ArrayRef> = chunked_array
                 .chunks()
-                .map(|chunk| -> PyResult<ArrayRef> { Ok(chunk.into_arrow()?) })
+                .map(|chunk| -> PyResult<ArrayRef> { Ok(chunk.into_arrow(&arrow_dtype)?) })
                 .collect::<PyResult<Vec<ArrayRef>>>()?;
             if chunks.is_empty() {
                 return Err(PyValueError::new_err("No chunks in array"));
@@ -141,7 +146,7 @@ impl PyArray {
         } else {
             Ok(vortex
                 .clone()
-                .into_arrow()?
+                .into_arrow_preferred()?
                 .into_data()
                 .to_pyarrow(py)?
                 .into_bound(py))
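
The comment added above in `pyvortex/src/array.rs` explains why chunked arrays need a single target type. A hedged sketch of the same pattern in plain Rust follows; `chunks_to_arrow` is a hypothetical helper name, and the method signatures are inferred from the call sites in this diff.

```rust
use arrow_array::ArrayRef;
use vortex::array::ChunkedArray;
use vortex::arrow::{infer_data_type, IntoArrowArray};
use vortex::error::VortexResult;
use vortex::ArrayDType;

// Hypothetical helper: convert every chunk of a ChunkedArray to the *same*
// Arrow DataType. Without a shared target, two chunks with different encodings
// could each pick a different "preferred" Arrow type for the same Vortex DType.
fn chunks_to_arrow(chunked: &ChunkedArray) -> VortexResult<Vec<ArrayRef>> {
    let target = infer_data_type(chunked.dtype())?;
    chunked
        .chunks()
        .map(|chunk| chunk.into_arrow(&target))
        .collect()
}
```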

vortex-array/src/array/bool/compute/mod.rs

Lines changed: 6 additions & 0 deletions
@@ -1,6 +1,7 @@
 use crate::array::BoolEncoding;
 use crate::compute::{
     BinaryBooleanFn, FillForwardFn, FillNullFn, FilterFn, InvertFn, ScalarAtFn, SliceFn, TakeFn,
+    ToArrowFn,
 };
 use crate::vtable::ComputeVTable;
 use crate::ArrayData;
@@ -13,6 +14,7 @@ mod invert;
 mod scalar_at;
 mod slice;
 mod take;
+mod to_arrow;

 impl ComputeVTable for BoolEncoding {
     fn binary_boolean_fn(&self) -> Option<&dyn BinaryBooleanFn<ArrayData>> {
@@ -50,4 +52,8 @@ impl ComputeVTable for BoolEncoding {
     fn take_fn(&self) -> Option<&dyn TakeFn<ArrayData>> {
         Some(self)
     }
+
+    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<ArrayData>> {
+        Some(self)
+    }
 }
(New file: the `to_arrow` module added to `vortex-array/src/array/bool/compute/mod.rs` above.)

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, BooleanArray as ArrowBoolArray};
+use arrow_schema::DataType;
+use vortex_error::{vortex_bail, VortexResult};
+
+use crate::array::{BoolArray, BoolEncoding};
+use crate::compute::ToArrowFn;
+use crate::validity::ArrayValidity;
+use crate::IntoArrayData;
+
+impl ToArrowFn<BoolArray> for BoolEncoding {
+    fn to_arrow(&self, array: &BoolArray, data_type: &DataType) -> VortexResult<Option<ArrayRef>> {
+        if data_type != &DataType::Boolean {
+            vortex_bail!("Unsupported data type: {data_type}");
+        }
+
+        Ok(Some(Arc::new(ArrowBoolArray::new(
+            array.boolean_buffer(),
+            array.logical_validity()?.to_null_buffer(),
+        ))))
+    }
+}

vortex-array/src/array/chunked/canonical.rs

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ use crate::array::null::NullArray;
 use crate::array::primitive::PrimitiveArray;
 use crate::array::struct_::StructArray;
 use crate::array::{BinaryView, BoolArray, ListArray, VarBinViewArray};
+use crate::arrow::IntoArrowArray;
 use crate::compute::{scalar_at, slice, try_cast};
 use crate::validity::Validity;
 use crate::{

vortex-array/src/array/constant/canonical.rs

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ use vortex_scalar::{BinaryScalar, BoolScalar, ExtScalar, Utf8Scalar};
 use crate::array::constant::ConstantArray;
 use crate::array::primitive::PrimitiveArray;
 use crate::array::{BinaryView, BoolArray, ExtensionArray, NullArray, VarBinViewArray};
+use crate::arrow::IntoArrowArray;
 use crate::validity::Validity;
 use crate::{ArrayDType, ArrayLen, Canonical, IntoArrayData, IntoCanonical};

vortex-array/src/array/extension/compute/mod.rs

Lines changed: 8 additions & 1 deletion
@@ -1,11 +1,14 @@
 mod compare;
+mod to_arrow;

 use vortex_error::VortexResult;
 use vortex_scalar::Scalar;

 use crate::array::extension::ExtensionArray;
 use crate::array::ExtensionEncoding;
-use crate::compute::{scalar_at, slice, take, CastFn, CompareFn, ScalarAtFn, SliceFn, TakeFn};
+use crate::compute::{
+    scalar_at, slice, take, CastFn, CompareFn, ScalarAtFn, SliceFn, TakeFn, ToArrowFn,
+};
 use crate::variants::ExtensionArrayTrait;
 use crate::vtable::ComputeVTable;
 use crate::{ArrayData, IntoArrayData};
@@ -33,6 +36,10 @@ impl ComputeVTable for ExtensionEncoding {
     fn take_fn(&self) -> Option<&dyn TakeFn<ArrayData>> {
         Some(self)
     }
+
+    fn to_arrow_fn(&self) -> Option<&dyn ToArrowFn<ArrayData>> {
+        Some(self)
+    }
 }

 impl ScalarAtFn<ExtensionArray> for ExtensionEncoding {
(New file: the `to_arrow` module added to `vortex-array/src/array/extension/compute/mod.rs` above.)

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+use std::sync::Arc;
+
+use arrow_array::{
+    ArrayRef, Date32Array, Date64Array, Time32MillisecondArray, Time32SecondArray,
+    Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray,
+    TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray,
+};
+use arrow_schema::DataType;
+use vortex_datetime_dtype::{is_temporal_ext_type, TemporalMetadata, TimeUnit};
+use vortex_dtype::{DType, NativePType};
+use vortex_error::{vortex_bail, VortexResult};
+
+use crate::array::{ExtensionArray, ExtensionEncoding, TemporalArray};
+use crate::canonical::IntoArrayVariant;
+use crate::compute::{to_arrow, try_cast, ToArrowFn};
+use crate::validity::ArrayValidity;
+use crate::{ArrayDType, IntoArrayData};
+
+impl ToArrowFn<ExtensionArray> for ExtensionEncoding {
+    fn to_arrow(
+        &self,
+        array: &ExtensionArray,
+        data_type: &DataType,
+    ) -> VortexResult<Option<ArrayRef>> {
+        // NOTE(ngates): this is really gross... but I guess it's ok given how tightly integrated
+        // we are with Arrow.
+        if is_temporal_ext_type(array.id()) {
+            temporal_to_arrow(TemporalArray::try_from(array.clone().into_array())?).map(Some)
+        } else {
+            // Convert storage array directly into arrow, losing type information
+            // that will let us round-trip.
+            // TODO(aduffy): https://github.com/spiraldb/vortex/issues/1167
+            to_arrow(array.storage(), data_type).map(Some)
+        }
+    }
+}
+
+fn temporal_to_arrow(temporal_array: TemporalArray) -> VortexResult<ArrayRef> {
+    macro_rules! extract_temporal_values {
+        ($values:expr, $prim:ty) => {{
+            let temporal_values = try_cast(
+                $values,
+                &DType::Primitive(<$prim as NativePType>::PTYPE, $values.dtype().nullability()),
+            )?
+            .into_primitive()?;
+            let nulls = temporal_values.logical_validity()?.to_null_buffer();
+            let scalars = temporal_values.into_buffer().into_arrow_scalar_buffer();
+
+            (scalars, nulls)
+        }};
+    }
+
+    Ok(match temporal_array.temporal_metadata() {
+        TemporalMetadata::Date(time_unit) => match time_unit {
+            TimeUnit::D => {
+                let (scalars, nulls) =
+                    extract_temporal_values!(&temporal_array.temporal_values(), i32);
+                Arc::new(Date32Array::new(scalars, nulls))
+            }
+            TimeUnit::Ms => {
+                let (scalars, nulls) =
+                    extract_temporal_values!(&temporal_array.temporal_values(), i64);
+                Arc::new(Date64Array::new(scalars, nulls))
+            }
+            _ => vortex_bail!(
+                "Invalid TimeUnit {time_unit} for {}",
+                temporal_array.ext_dtype().id()
+            ),
+        },
+        TemporalMetadata::Time(time_unit) => match time_unit {
+            TimeUnit::S => {
+                let (scalars, nulls) =
+                    extract_temporal_values!(&temporal_array.temporal_values(), i32);
+                Arc::new(Time32SecondArray::new(scalars, nulls))
+            }
+            TimeUnit::Ms => {
+                let (scalars, nulls) =
+                    extract_temporal_values!(&temporal_array.temporal_values(), i32);
+                Arc::new(Time32MillisecondArray::new(scalars, nulls))
+            }
+            TimeUnit::Us => {
+                let (scalars, nulls) =
+                    extract_temporal_values!(&temporal_array.temporal_values(), i64);
+                Arc::new(Time64MicrosecondArray::new(scalars, nulls))
+            }
+            TimeUnit::Ns => {
+                let (scalars, nulls) =
+                    extract_temporal_values!(&temporal_array.temporal_values(), i64);
+                Arc::new(Time64NanosecondArray::new(scalars, nulls))
+            }
+            _ => vortex_bail!(
+                "Invalid TimeUnit {time_unit} for {}",
+                temporal_array.ext_dtype().id()
+            ),
+        },
+        TemporalMetadata::Timestamp(time_unit, _) => {
+            let (scalars, nulls) = extract_temporal_values!(&temporal_array.temporal_values(), i64);
+            match time_unit {
+                TimeUnit::Ns => Arc::new(TimestampNanosecondArray::new(scalars, nulls)),
+                TimeUnit::Us => Arc::new(TimestampMicrosecondArray::new(scalars, nulls)),
+                TimeUnit::Ms => Arc::new(TimestampMillisecondArray::new(scalars, nulls)),
+                TimeUnit::S => Arc::new(TimestampSecondArray::new(scalars, nulls)),
+                _ => vortex_bail!(
+                    "Invalid TimeUnit {time_unit} for {}",
+                    temporal_array.ext_dtype().id()
+                ),
+            }
+        }
+    })
+}

0 commit comments

Comments
 (0)