|
| 1 | +use std::sync::Arc; |
| 2 | + |
| 3 | +use arrow_array::{ |
| 4 | + ArrayRef, Date32Array, Date64Array, Time32MillisecondArray, Time32SecondArray, |
| 5 | + Time64MicrosecondArray, Time64NanosecondArray, TimestampMicrosecondArray, |
| 6 | + TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, |
| 7 | +}; |
| 8 | +use arrow_schema::DataType; |
| 9 | +use vortex_datetime_dtype::{is_temporal_ext_type, TemporalMetadata, TimeUnit}; |
| 10 | +use vortex_dtype::{DType, NativePType}; |
| 11 | +use vortex_error::{vortex_bail, VortexResult}; |
| 12 | + |
| 13 | +use crate::array::{ExtensionArray, ExtensionEncoding, TemporalArray}; |
| 14 | +use crate::canonical::IntoArrayVariant; |
| 15 | +use crate::compute::{to_arrow, try_cast, ToArrowFn}; |
| 16 | +use crate::validity::ArrayValidity; |
| 17 | +use crate::{ArrayDType, IntoArrayData}; |
| 18 | + |
| 19 | +impl ToArrowFn<ExtensionArray> for ExtensionEncoding { |
| 20 | + fn to_arrow( |
| 21 | + &self, |
| 22 | + array: &ExtensionArray, |
| 23 | + data_type: &DataType, |
| 24 | + ) -> VortexResult<Option<ArrayRef>> { |
| 25 | + // NOTE(ngates): this is really gross... but I guess it's ok given how tightly integrated |
| 26 | + // we are with Arrow. |
| 27 | + if is_temporal_ext_type(array.id()) { |
| 28 | + temporal_to_arrow(TemporalArray::try_from(array.clone().into_array())?).map(Some) |
| 29 | + } else { |
| 30 | + // Convert storage array directly into arrow, losing type information |
| 31 | + // that will let us round-trip. |
| 32 | + // TODO(aduffy): https://github.com/spiraldb/vortex/issues/1167 |
| 33 | + to_arrow(array.storage(), data_type).map(Some) |
| 34 | + } |
| 35 | + } |
| 36 | +} |
| 37 | + |
| 38 | +fn temporal_to_arrow(temporal_array: TemporalArray) -> VortexResult<ArrayRef> { |
| 39 | + macro_rules! extract_temporal_values { |
| 40 | + ($values:expr, $prim:ty) => {{ |
| 41 | + let temporal_values = try_cast( |
| 42 | + $values, |
| 43 | + &DType::Primitive(<$prim as NativePType>::PTYPE, $values.dtype().nullability()), |
| 44 | + )? |
| 45 | + .into_primitive()?; |
| 46 | + let nulls = temporal_values.logical_validity()?.to_null_buffer(); |
| 47 | + let scalars = temporal_values.into_buffer().into_arrow_scalar_buffer(); |
| 48 | + |
| 49 | + (scalars, nulls) |
| 50 | + }}; |
| 51 | + } |
| 52 | + |
| 53 | + Ok(match temporal_array.temporal_metadata() { |
| 54 | + TemporalMetadata::Date(time_unit) => match time_unit { |
| 55 | + TimeUnit::D => { |
| 56 | + let (scalars, nulls) = |
| 57 | + extract_temporal_values!(&temporal_array.temporal_values(), i32); |
| 58 | + Arc::new(Date32Array::new(scalars, nulls)) |
| 59 | + } |
| 60 | + TimeUnit::Ms => { |
| 61 | + let (scalars, nulls) = |
| 62 | + extract_temporal_values!(&temporal_array.temporal_values(), i64); |
| 63 | + Arc::new(Date64Array::new(scalars, nulls)) |
| 64 | + } |
| 65 | + _ => vortex_bail!( |
| 66 | + "Invalid TimeUnit {time_unit} for {}", |
| 67 | + temporal_array.ext_dtype().id() |
| 68 | + ), |
| 69 | + }, |
| 70 | + TemporalMetadata::Time(time_unit) => match time_unit { |
| 71 | + TimeUnit::S => { |
| 72 | + let (scalars, nulls) = |
| 73 | + extract_temporal_values!(&temporal_array.temporal_values(), i32); |
| 74 | + Arc::new(Time32SecondArray::new(scalars, nulls)) |
| 75 | + } |
| 76 | + TimeUnit::Ms => { |
| 77 | + let (scalars, nulls) = |
| 78 | + extract_temporal_values!(&temporal_array.temporal_values(), i32); |
| 79 | + Arc::new(Time32MillisecondArray::new(scalars, nulls)) |
| 80 | + } |
| 81 | + TimeUnit::Us => { |
| 82 | + let (scalars, nulls) = |
| 83 | + extract_temporal_values!(&temporal_array.temporal_values(), i64); |
| 84 | + Arc::new(Time64MicrosecondArray::new(scalars, nulls)) |
| 85 | + } |
| 86 | + TimeUnit::Ns => { |
| 87 | + let (scalars, nulls) = |
| 88 | + extract_temporal_values!(&temporal_array.temporal_values(), i64); |
| 89 | + Arc::new(Time64NanosecondArray::new(scalars, nulls)) |
| 90 | + } |
| 91 | + _ => vortex_bail!( |
| 92 | + "Invalid TimeUnit {time_unit} for {}", |
| 93 | + temporal_array.ext_dtype().id() |
| 94 | + ), |
| 95 | + }, |
| 96 | + TemporalMetadata::Timestamp(time_unit, _) => { |
| 97 | + let (scalars, nulls) = extract_temporal_values!(&temporal_array.temporal_values(), i64); |
| 98 | + match time_unit { |
| 99 | + TimeUnit::Ns => Arc::new(TimestampNanosecondArray::new(scalars, nulls)), |
| 100 | + TimeUnit::Us => Arc::new(TimestampMicrosecondArray::new(scalars, nulls)), |
| 101 | + TimeUnit::Ms => Arc::new(TimestampMillisecondArray::new(scalars, nulls)), |
| 102 | + TimeUnit::S => Arc::new(TimestampSecondArray::new(scalars, nulls)), |
| 103 | + _ => vortex_bail!( |
| 104 | + "Invalid TimeUnit {time_unit} for {}", |
| 105 | + temporal_array.ext_dtype().id() |
| 106 | + ), |
| 107 | + } |
| 108 | + } |
| 109 | + }) |
| 110 | +} |
0 commit comments