|
| 1 | +// SPDX-License-Identifier: Apache-2.0 |
| 2 | +// SPDX-FileCopyrightText: Copyright the Vortex contributors |
| 3 | + |
| 4 | +use std::ops::Deref; |
| 5 | +use std::sync::Arc; |
| 6 | + |
| 7 | +use vortex_dtype::Nullability::NonNullable; |
| 8 | +use vortex_dtype::{ |
| 9 | + DType, NativeDecimalType, NativePType, match_each_decimal_value_type, match_each_native_ptype, |
| 10 | +}; |
| 11 | +use vortex_error::VortexExpect; |
| 12 | +use vortex_vector::binaryview::{BinaryViewType, BinaryViewVector}; |
| 13 | +use vortex_vector::bool::BoolVector; |
| 14 | +use vortex_vector::decimal::{DVector, DecimalVector}; |
| 15 | +use vortex_vector::fixed_size_list::FixedSizeListVector; |
| 16 | +use vortex_vector::listview::ListViewVector; |
| 17 | +use vortex_vector::null::NullVector; |
| 18 | +use vortex_vector::primitive::{PVector, PrimitiveVector}; |
| 19 | +use vortex_vector::struct_::StructVector; |
| 20 | +use vortex_vector::{Vector, VectorOps}; |
| 21 | + |
| 22 | +use crate::arrays::{ |
| 23 | + BoolArray, DecimalArray, ExtensionArray, FixedSizeListArray, ListViewArray, NullArray, |
| 24 | + PrimitiveArray, StructArray, VarBinViewArray, |
| 25 | +}; |
| 26 | +use crate::validity::Validity; |
| 27 | +use crate::{ArrayRef, IntoArray}; |
| 28 | + |
| 29 | +/// Trait for converting vector types into arrays. |
| 30 | +pub trait VectorIntoArray { |
| 31 | + /// Converts the vector into an array of the specified data type. |
| 32 | + fn into_array(self, dtype: DType) -> ArrayRef; |
| 33 | +} |
| 34 | + |
| 35 | +impl VectorIntoArray for Vector { |
| 36 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 37 | + match dtype { |
| 38 | + DType::Null => self.into_null().into_array(dtype), |
| 39 | + DType::Bool(_) => self.into_bool().into_array(dtype), |
| 40 | + DType::Primitive(..) => self.into_primitive().into_array(dtype), |
| 41 | + DType::Decimal(..) => self.into_decimal().into_array(dtype), |
| 42 | + DType::Utf8(_) => self.into_string().into_array(dtype), |
| 43 | + DType::Binary(_) => self.into_binary().into_array(dtype), |
| 44 | + DType::List(..) => self.into_list().into_array(dtype), |
| 45 | + DType::FixedSizeList(..) => self.into_fixed_size_list().into_array(dtype), |
| 46 | + DType::Struct(..) => self.into_struct().into_array(dtype), |
| 47 | + DType::Extension(ext_dtype) => { |
| 48 | + let storage = self.into_array(ext_dtype.storage_dtype().clone()); |
| 49 | + ExtensionArray::new(ext_dtype, storage).into_array() |
| 50 | + } |
| 51 | + } |
| 52 | + } |
| 53 | +} |
| 54 | + |
| 55 | +impl VectorIntoArray for NullVector { |
| 56 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 57 | + assert!(matches!(dtype, DType::Null)); |
| 58 | + NullArray::new(self.len()).into_array() |
| 59 | + } |
| 60 | +} |
| 61 | + |
| 62 | +impl VectorIntoArray for BoolVector { |
| 63 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 64 | + assert!(matches!(dtype, DType::Bool(_))); |
| 65 | + |
| 66 | + let (bits, validity) = self.into_parts(); |
| 67 | + BoolArray::from_bit_buffer(bits, Validity::from_mask(validity, dtype.nullability())) |
| 68 | + .into_array() |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +impl VectorIntoArray for PrimitiveVector { |
| 73 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 74 | + match_each_native_ptype!(self.ptype(), |T| { |
| 75 | + <T as NativePType>::downcast(self).into_array(dtype) |
| 76 | + }) |
| 77 | + } |
| 78 | +} |
| 79 | + |
| 80 | +impl<T: NativePType> VectorIntoArray for PVector<T> { |
| 81 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 82 | + assert!(matches!(dtype, DType::Primitive(_, _))); |
| 83 | + assert_eq!(T::PTYPE, dtype.as_ptype()); |
| 84 | + |
| 85 | + let (values, validity) = self.into_parts(); |
| 86 | + // SAFETY: vectors maintain all invariants required for array creation |
| 87 | + unsafe { |
| 88 | + PrimitiveArray::new_unchecked::<T>( |
| 89 | + values, |
| 90 | + Validity::from_mask(validity, dtype.nullability()), |
| 91 | + ) |
| 92 | + } |
| 93 | + .into_array() |
| 94 | + } |
| 95 | +} |
| 96 | + |
| 97 | +impl VectorIntoArray for DecimalVector { |
| 98 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 99 | + match_each_decimal_value_type!(self.decimal_type(), |D| { |
| 100 | + <D as NativeDecimalType>::downcast(self).into_array(dtype) |
| 101 | + }) |
| 102 | + } |
| 103 | +} |
| 104 | + |
| 105 | +impl<D: NativeDecimalType> VectorIntoArray for DVector<D> { |
| 106 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 107 | + assert!(matches!(dtype, DType::Decimal(_, _))); |
| 108 | + |
| 109 | + let nullability = dtype.nullability(); |
| 110 | + let dec_dtype = dtype |
| 111 | + .into_decimal_opt() |
| 112 | + .vortex_expect("expected decimal DType"); |
| 113 | + assert_eq!(dec_dtype.precision(), self.precision()); |
| 114 | + assert_eq!(dec_dtype.scale(), self.scale()); |
| 115 | + |
| 116 | + let (_ps, values, validity) = self.into_parts(); |
| 117 | + // SAFETY: vectors maintain all invariants required for array creation |
| 118 | + unsafe { |
| 119 | + DecimalArray::new_unchecked::<D>( |
| 120 | + values, |
| 121 | + dec_dtype, |
| 122 | + Validity::from_mask(validity, nullability), |
| 123 | + ) |
| 124 | + } |
| 125 | + .into_array() |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +impl<T: BinaryViewType> VectorIntoArray for BinaryViewVector<T> { |
| 130 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 131 | + assert!(matches!(dtype, DType::Utf8(_))); |
| 132 | + |
| 133 | + let (views, buffers, validity) = self.into_parts(); |
| 134 | + let validity = Validity::from_mask(validity, dtype.nullability()); |
| 135 | + |
| 136 | + let buffers = Arc::try_unwrap(buffers).unwrap_or_else(|b| (*b).clone()); |
| 137 | + |
| 138 | + // SAFETY: vectors maintain all invariants required for array creation |
| 139 | + unsafe { |
| 140 | + VarBinViewArray::new_unchecked(views, buffers.into_iter().collect(), dtype, validity) |
| 141 | + } |
| 142 | + .into_array() |
| 143 | + } |
| 144 | +} |
| 145 | + |
| 146 | +impl VectorIntoArray for ListViewVector { |
| 147 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 148 | + assert!(matches!(dtype, DType::List(_, _))); |
| 149 | + |
| 150 | + let (elements, offsets, sizes, validity) = self.into_parts(); |
| 151 | + let validity = Validity::from_mask(validity, dtype.nullability()); |
| 152 | + |
| 153 | + let elements_dtype = dtype |
| 154 | + .into_list_element_opt() |
| 155 | + .vortex_expect("expected list") |
| 156 | + .deref() |
| 157 | + .clone(); |
| 158 | + let elements = Arc::try_unwrap(elements) |
| 159 | + .unwrap_or_else(|e| (*e).clone()) |
| 160 | + .into_array(elements_dtype); |
| 161 | + |
| 162 | + let offsets_dtype = DType::Primitive(offsets.ptype(), NonNullable); |
| 163 | + let offsets = offsets.into_array(offsets_dtype); |
| 164 | + |
| 165 | + let sizes_dtype = DType::Primitive(sizes.ptype(), NonNullable); |
| 166 | + let sizes = sizes.into_array(sizes_dtype); |
| 167 | + |
| 168 | + // SAFETY: vectors maintain all invariants required for array creation |
| 169 | + unsafe { ListViewArray::new_unchecked(elements, offsets, sizes, validity) }.into_array() |
| 170 | + } |
| 171 | +} |
| 172 | + |
| 173 | +impl VectorIntoArray for FixedSizeListVector { |
| 174 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 175 | + assert!(matches!(dtype, DType::FixedSizeList(_, _, _))); |
| 176 | + |
| 177 | + let len = self.len(); |
| 178 | + let (elements, size, validity) = self.into_parts(); |
| 179 | + let validity = Validity::from_mask(validity, dtype.nullability()); |
| 180 | + |
| 181 | + let elements_dtype = dtype |
| 182 | + .into_fixed_size_list_element_opt() |
| 183 | + .vortex_expect("expected fixed size list") |
| 184 | + .deref() |
| 185 | + .clone(); |
| 186 | + let elements = Arc::try_unwrap(elements) |
| 187 | + .unwrap_or_else(|e| (*e).clone()) |
| 188 | + .into_array(elements_dtype); |
| 189 | + |
| 190 | + // SAFETY: vectors maintain all invariants required for array creation |
| 191 | + unsafe { FixedSizeListArray::new_unchecked(elements, size, validity, len) }.into_array() |
| 192 | + } |
| 193 | +} |
| 194 | + |
| 195 | +impl VectorIntoArray for StructVector { |
| 196 | + fn into_array(self, dtype: DType) -> ArrayRef { |
| 197 | + assert!(matches!(dtype, DType::Struct(_, _))); |
| 198 | + |
| 199 | + let len = self.len(); |
| 200 | + let (fields, validity) = self.into_parts(); |
| 201 | + let validity = Validity::from_mask(validity, dtype.nullability()); |
| 202 | + |
| 203 | + let struct_fields = dtype.into_struct_fields(); |
| 204 | + assert_eq!(fields.len(), struct_fields.nfields()); |
| 205 | + |
| 206 | + let field_arrays: Vec<ArrayRef> = Arc::try_unwrap(fields) |
| 207 | + .unwrap_or_else(|f| (*f).clone()) |
| 208 | + .into_iter() |
| 209 | + .zip(struct_fields.fields()) |
| 210 | + .map(|(field_vector, field_dtype)| field_vector.into_array(field_dtype)) |
| 211 | + .collect(); |
| 212 | + |
| 213 | + // SAFETY: vectors maintain all invariants required for array creation |
| 214 | + unsafe { StructArray::new_unchecked(field_arrays, struct_fields, len, validity) } |
| 215 | + .into_array() |
| 216 | + } |
| 217 | +} |
0 commit comments