// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

//! A more configurable variant of the `StructStrategy` that allows overriding
//! specific leaf fields with custom write strategies.
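//!
//! A minimal construction sketch (`compact`, `validity`, and `flat` stand in for
//! whatever [`LayoutStrategy`] implementations the caller already has; the field
//! name is illustrative):
//!
//! ```ignore
//! let mut overrides: HashMap<FieldPath, Arc<dyn LayoutStrategy>> = HashMap::new();
//! // Route the `embeddings` column through the compact strategy.
//! overrides.insert(FieldPath::from_name("embeddings"), compact);
//! let strategy = PathStrategy::new(overrides, validity, flat);
//! ```
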
use std::sync::Arc;

use async_trait::async_trait;
use futures::StreamExt;
use futures::TryStreamExt;
use futures::future::try_join_all;
use futures::pin_mut;
use itertools::Itertools;
use vortex_array::ArrayContext;
use vortex_array::ArrayRef;
use vortex_array::IntoArray;
use vortex_array::ToCanonical;
use vortex_dtype::DType;
use vortex_dtype::Field;
use vortex_dtype::FieldName;
use vortex_dtype::FieldPath;
use vortex_dtype::Nullability;
use vortex_error::VortexError;
use vortex_error::VortexResult;
use vortex_error::vortex_bail;
use vortex_io::kanal_ext::KanalExt;
use vortex_io::runtime::Handle;
use vortex_utils::aliases::DefaultHashBuilder;
use vortex_utils::aliases::hash_map::HashMap;
use vortex_utils::aliases::hash_set::HashSet;

use crate::IntoLayout;
use crate::LayoutRef;
use crate::LayoutStrategy;
use crate::layouts::struct_::StructLayout;
use crate::segments::SegmentSinkRef;
use crate::sequence::SendableSequentialStream;
use crate::sequence::SequenceId;
use crate::sequence::SequencePointer;
use crate::sequence::SequentialStreamAdapter;
use crate::sequence::SequentialStreamExt;

pub struct PathStrategy {
    // A set of leaf field overrides, e.g. to force one column to be compact-compressed.
    leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
    // The writer for any validity arrays that may be present.
    validity: Arc<dyn LayoutStrategy>,
    // The fallback writer for any fields that do not have an explicit writer in `leaf_writers`.
    fallback: Arc<dyn LayoutStrategy>,
}

impl PathStrategy {
    /// Create a new `PathStrategy` from per-leaf overrides, a validity writer,
    /// and a fallback writer.
    pub fn new(
        leaf_writers: HashMap<FieldPath, Arc<dyn LayoutStrategy>>,
        validity: Arc<dyn LayoutStrategy>,
        fallback: Arc<dyn LayoutStrategy>,
    ) -> Self {
        Self {
            leaf_writers,
            validity,
            fallback,
        }
    }
}

impl PathStrategy {
    fn descend(&self, field: &Field) -> Self {
        // Start from the existing set of overrides, retaining only those whose
        // path starts with the current field.
        let mut new_writers = self.leaf_writers.clone();
        new_writers.retain(|k, _| k.starts_with_field(field));
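        // For example, with overrides for `a.b`, `a.c`, and `d`, descending into
        // field `a` keeps `a.b` and `a.c` and drops `d`.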

        Self {
            leaf_writers: new_writers,
            validity: self.validity.clone(),
            fallback: self.fallback.clone(),
        }
    }
}

/// Specialized strategy for when we know the input schema exactly.
#[async_trait]
impl LayoutStrategy for PathStrategy {
    async fn write_stream(
        &self,
        ctx: ArrayContext,
        segment_sink: SegmentSinkRef,
        stream: SendableSequentialStream,
        mut eof: SequencePointer,
        handle: Handle,
    ) -> VortexResult<LayoutRef> {
        let dtype = stream.dtype().clone();
        let struct_dtype = dtype.as_struct_fields();

        // Check for unique field names at write time.
        if HashSet::<_, DefaultHashBuilder>::from_iter(struct_dtype.names().iter()).len()
            != struct_dtype.names().len()
        {
            vortex_bail!("StructLayout must have unique field names");
        }
        let is_nullable = dtype.is_nullable();

        // Optimization: when there are no fields, don't spawn any work and just write a trivial
        // StructLayout.
        if struct_dtype.nfields() == 0 && !is_nullable {
            let row_count = stream
                .try_fold(
                    0u64,
                    |acc, (_, arr)| async move { Ok(acc + arr.len() as u64) },
                )
                .await?;
            return Ok(StructLayout::new(row_count, dtype, vec![]).into_layout());
        }

        // stream<struct_chunk> -> stream<vec<column_chunk>>
        let columns_vec_stream = stream.map(move |chunk| {
            let (sequence_id, chunk) = chunk?;
            let mut sequence_pointer = sequence_id.descend();
            let struct_chunk = chunk.to_struct();
            let mut columns: Vec<(SequenceId, ArrayRef)> = Vec::new();
            if is_nullable {
                columns.push((
                    sequence_pointer.advance(),
                    chunk.validity_mask().into_array(),
                ));
            }

            columns.extend(
                struct_chunk
                    .fields()
                    .iter()
                    .map(|field| (sequence_pointer.advance(), field.to_array())),
            );

            Ok(columns)
        });

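        // One transposed stream per column (plus one for validity when nullable),
        // each backed by a bounded capacity-1 kanal channel for backpressure.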
        let mut stream_count = struct_dtype.nfields();
        if is_nullable {
            stream_count += 1;
        }

        let (column_streams_tx, column_streams_rx): (Vec<_>, Vec<_>) =
            (0..stream_count).map(|_| kanal::bounded_async(1)).unzip();

        // Spawn a task to fan out column chunks to their respective transposed streams.
        // On error, the error is broadcast to every column stream and the fan-out stops.
        handle
            .spawn(async move {
                pin_mut!(columns_vec_stream);
                while let Some(result) = columns_vec_stream.next().await {
                    match result {
                        Ok(columns) => {
                            for (tx, column) in column_streams_tx.iter().zip_eq(columns.into_iter())
                            {
                                let _ = tx.send(Ok(column)).await;
                            }
                        }
                        Err(e) => {
                            let e: Arc<VortexError> = Arc::new(e);
                            for tx in column_streams_tx.iter() {
                                let _ = tx.send(Err(VortexError::from(e.clone()))).await;
                            }
                            break;
                        }
                    }
                }
            })
            .detach();

        // The first child column is the validity; subsequent children are the
        // individual struct fields.
        let column_dtypes: Vec<DType> = if is_nullable {
            std::iter::once(DType::Bool(Nullability::NonNullable))
                .chain(struct_dtype.fields())
                .collect()
        } else {
            struct_dtype.fields().collect()
        };

        let column_names: Vec<FieldName> = if is_nullable {
            std::iter::once(FieldName::from("__validity"))
                .chain(struct_dtype.names().iter().cloned())
                .collect()
        } else {
            struct_dtype.names().iter().cloned().collect()
        };
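        // e.g. a nullable struct {a, b} yields children named [__validity, a, b]
        // with dtypes [Bool(NonNullable), dtype(a), dtype(b)].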

        let layout_futures: Vec<_> = column_dtypes
            .into_iter()
            .zip_eq(column_streams_rx)
            .zip_eq(column_names)
            .enumerate()
            .map(move |(index, ((dtype, recv), name))| {
                let column_stream =
                    SequentialStreamAdapter::new(dtype.clone(), recv.into_stream().boxed())
                        .sendable();
                let child_eof = eof.split_off();
                let field = Field::Name(name.clone());
                handle.spawn_nested(|h| {
                    let validity = self.validity.clone();
                    // Pick this column's writer: an explicit leaf override wins;
                    // otherwise descend into struct columns, or fall back for leaves.
                    let writer = self
                        .leaf_writers
                        .get(&FieldPath::from_name(name))
                        .cloned()
                        .unwrap_or_else(|| {
                            if dtype.is_struct() {
                                // Step into the field path for struct columns
                                Arc::new(self.descend(&field))
                            } else {
                                // Use the fallback for leaf columns
                                self.fallback.clone()
                            }
                        });
                    let ctx = ctx.clone();
                    let segment_sink = segment_sink.clone();

                    async move {
                        if index == 0 && is_nullable {
                            // The first column of a nullable struct is the validity stream.
                            validity
                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                                .await
                        } else {
                            // Write with the override, descended, or fallback writer.
                            writer
                                .write_stream(ctx, segment_sink, column_stream, child_eof, h)
                                .await
                        }
                    }
                })
            })
            .collect();
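        // Drive all column writers concurrently; `try_join_all` preserves input
        // order, so the child layouts stay aligned with `column_names`.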

        let column_layouts = try_join_all(layout_futures).await?;
        // TODO(os): the transposed stream could track row counts as well.
        // This invariant must hold: all columns share the struct layout's row count.
        let row_count = column_layouts.first().map(|l| l.row_count()).unwrap_or(0);
        Ok(StructLayout::new(row_count, dtype, column_layouts).into_layout())
    }
}