Skip to content

Commit d589370

Browse files
authored
Feature: Add FixedSizeList support in vortex-scalar (#4380)
Adds `vortex-scalar` support for `FixedSizeList`. We reuse `ListScalar` to represent **both** the `List` and `FixedSizeList` logical types, as there is no effective difference between a single scalar value of a list and a single scalar value of a fixed-size list. Also refactors the unit `tests` module to make it clearer what each submodule is actually testing, and adds several tests for `FixedSizeList`. Also fixes some `no_coercion` tests that checked the wrong things. For reviewers: it might be a good idea to review each commit separately. --------- Signed-off-by: Connor Tsui <[email protected]>
1 parent 1a28fa2 commit d589370

File tree

8 files changed

+3327
-1223
lines changed

8 files changed

+3327
-1223
lines changed

vortex-scalar/src/list.rs

Lines changed: 116 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,26 @@
33

44
use std::fmt::{Display, Formatter};
55
use std::hash::Hash;
6-
use std::ops::Deref;
76
use std::sync::Arc;
87

98
use itertools::Itertools as _;
109
use vortex_dtype::{DType, Nullability};
11-
use vortex_error::{VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_panic};
10+
use vortex_error::{
11+
VortexError, VortexExpect as _, VortexResult, vortex_bail, vortex_err, vortex_panic,
12+
};
1213

1314
use crate::{InnerScalarValue, Scalar, ScalarValue};
1415

15-
/// A scalar value representing a list (array) of elements.
16+
/// A scalar value representing a list or fixed-size list (array) of elements.
1617
///
17-
/// This type provides a view into a list scalar value, which can contain
18-
/// zero or more elements of the same type, or be null.
18+
/// We use the same [`ListScalar`] to represent both variants since a single list scalar's data is
19+
/// identical to a single fixed-size list scalar.
20+
///
21+
/// This type provides a view into a list or fixed-size list scalar value which can contain zero or
22+
/// more elements of the same type, or be null. If the `dtype` is a [`FixedSizeList`], then the
23+
/// number of `elements` is equal to the `size` field of the [`FixedSizeList`].
24+
///
25+
/// [`FixedSizeList`]: DType::FixedSizeList
1926
#[derive(Debug)]
2027
pub struct ListScalar<'a> {
2128
dtype: &'a DType,
@@ -28,9 +35,16 @@ impl Display for ListScalar<'_> {
2835
match &self.elements {
2936
None => write!(f, "null"),
3037
Some(elems) => {
38+
let fixed_size_list_str: &dyn Display =
39+
if let DType::FixedSizeList(_, size, _) = self.dtype {
40+
&format!("fixed_size<{size}>")
41+
} else {
42+
&""
43+
};
44+
3145
write!(
3246
f,
33-
"[{}]",
47+
"{fixed_size_list_str}[{}]",
3448
elems
3549
.iter()
3650
.map(|e| Scalar::new(self.element_dtype().clone(), e.clone()))
@@ -101,10 +115,10 @@ impl<'a> ListScalar<'a> {
101115

102116
/// Returns the data type of the list's elements.
103117
pub fn element_dtype(&self) -> &DType {
104-
let DType::List(element_type, _) = self.dtype() else {
105-
unreachable!();
106-
};
107-
(*element_type).deref()
118+
self.dtype
119+
.as_list_element_opt()
120+
.unwrap_or_else(|| vortex_panic!("`ListScalar` somehow had dtype {}", self.dtype))
121+
.as_ref()
108122
}
109123

110124
/// Returns the element at the given index as a scalar.
@@ -129,14 +143,35 @@ impl<'a> ListScalar<'a> {
129143
})
130144
}
131145

146+
/// Casts the list to the target [`DType`].
147+
///
148+
/// # Errors
149+
///
150+
/// Returns an error if the target [`DType`] is not a [`List`] or [`FixedSizeList`], or if trying
151+
/// to cast to a [`FixedSizeList`] with the incorrect number of elements.
152+
///
153+
/// [`List`]: DType::List
154+
/// [`FixedSizeList`]: DType::FixedSizeList
132155
pub(crate) fn cast(&self, dtype: &DType) -> VortexResult<Scalar> {
133-
let DType::List(element_dtype, ..) = dtype else {
156+
let target_element_dtype = dtype
157+
.as_list_element_opt()
158+
.ok_or_else(|| {
159+
vortex_err!(
160+
"Cannot cast {} to {}: list can only be cast to a list or fixed-size list",
161+
self.dtype(),
162+
dtype
163+
)
164+
})?
165+
.as_ref();
166+
167+
if let DType::FixedSizeList(_, size, _) = dtype
168+
&& *size as usize != self.len()
169+
{
134170
vortex_bail!(
135-
"Cannot cast {} to {}: list can only be cast to list",
136-
self.dtype(),
137-
dtype
171+
"tried to cast to a `FixedSizeList[{size}]` but had {} elements",
172+
self.len()
138173
)
139-
};
174+
}
140175

141176
Ok(Scalar::new(
142177
dtype.clone(),
@@ -146,61 +181,101 @@ impl<'a> ListScalar<'a> {
146181
.vortex_expect("nullness handled in Scalar::cast")
147182
.iter()
148183
.map(|element| {
184+
// Recursively cast the elements of the list.
149185
Scalar::new(DType::clone(self.element_dtype), element.clone())
150-
.cast(element_dtype)
186+
.cast(target_element_dtype)
151187
.map(|x| x.value().clone())
152188
})
153-
.process_results(|iter| iter.collect())?,
189+
.collect::<VortexResult<Arc<[ScalarValue]>>>()?,
154190
)),
155191
))
156192
}
157193
}
158194

195+
/// A helper enum for creating a [`ListScalar`].
196+
enum ListKind {
197+
Variable,
198+
FixedSize,
199+
}
200+
201+
/// Helper functions to create a [`ListScalar`] as a [`Scalar`].
159202
impl Scalar {
203+
fn create_list(
204+
element_dtype: impl Into<Arc<DType>>,
205+
children: Vec<Scalar>,
206+
nullability: Nullability,
207+
list_kind: ListKind,
208+
) -> Self {
209+
let element_dtype = element_dtype.into();
210+
211+
let children: Arc<[ScalarValue]> = children
212+
.into_iter()
213+
.map(|child| {
214+
if child.dtype() != &*element_dtype {
215+
vortex_panic!(
216+
"tried to create list of {} with values of type {}",
217+
element_dtype,
218+
child.dtype()
219+
);
220+
}
221+
child.value
222+
})
223+
.collect();
224+
let size: u32 = children
225+
.len()
226+
.try_into()
227+
.vortex_expect("tried to create a list that was too large");
228+
229+
let dtype = match list_kind {
230+
ListKind::Variable => DType::List(element_dtype, nullability),
231+
ListKind::FixedSize => DType::FixedSizeList(element_dtype, size, nullability),
232+
};
233+
234+
Self::new(dtype, ScalarValue(InnerScalarValue::List(children)))
235+
}
236+
160237
/// Creates a new list scalar with the given element type and children.
161238
///
162239
/// # Panics
163240
///
164-
/// Panics if any child scalar has a different type than the element type.
241+
/// Panics if any child scalar has a different type than the element type, or if there are too
242+
/// many children.
165243
pub fn list(
166244
element_dtype: impl Into<Arc<DType>>,
167245
children: Vec<Scalar>,
168246
nullability: Nullability,
169247
) -> Self {
170-
let element_dtype = element_dtype.into();
171-
for child in &children {
172-
if child.dtype() != &*element_dtype {
173-
vortex_panic!(
174-
"tried to create list of {} with values of type {}",
175-
element_dtype,
176-
child.dtype()
177-
);
178-
}
179-
}
180-
Self::new(
181-
DType::List(element_dtype, nullability),
182-
ScalarValue(InnerScalarValue::List(
183-
children.into_iter().map(|x| x.value).collect(),
184-
)),
185-
)
248+
Self::create_list(element_dtype, children, nullability, ListKind::Variable)
186249
}
187250

188251
/// Creates a new empty list scalar with the given element type.
189252
pub fn list_empty(element_dtype: Arc<DType>, nullability: Nullability) -> Self {
190-
Self::new(
191-
DType::List(element_dtype, nullability),
192-
ScalarValue(InnerScalarValue::List(vec![].into())),
193-
)
253+
Self::create_list(element_dtype, vec![], nullability, ListKind::Variable)
254+
}
255+
256+
/// Creates a new fixed-size list scalar with the given element type and children.
257+
///
258+
/// # Panics
259+
///
260+
/// Panics if any child scalar has a different type than the element type, or if there are too
261+
/// many children.
262+
pub fn fixed_size_list(
263+
element_dtype: impl Into<Arc<DType>>,
264+
children: Vec<Scalar>,
265+
nullability: Nullability,
266+
) -> Self {
267+
Self::create_list(element_dtype, children, nullability, ListKind::FixedSize)
194268
}
195269
}
196270

197271
impl<'a> TryFrom<&'a Scalar> for ListScalar<'a> {
198272
type Error = VortexError;
199273

200274
fn try_from(value: &'a Scalar) -> Result<Self, Self::Error> {
201-
let DType::List(element_dtype, ..) = value.dtype() else {
202-
vortex_bail!("Expected list scalar, found {}", value.dtype())
203-
};
275+
let element_dtype = value
276+
.dtype()
277+
.as_list_element_opt()
278+
.ok_or_else(|| vortex_err!("Expected list scalar, found {}", value.dtype()))?;
204279

205280
Ok(Self {
206281
dtype: value.dtype(),

vortex-scalar/src/scalar.rs

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,7 @@ impl Scalar {
125125
DType::Utf8(_) => self.as_utf8().cast(target),
126126
DType::Binary(_) => self.as_binary().cast(target),
127127
DType::Struct(..) => self.as_struct().cast(target),
128-
DType::List(..) => self.as_list().cast(target),
129-
DType::FixedSizeList(..) => unimplemented!("TODO(connor)[FixedSizeList]"),
128+
DType::List(..) | DType::FixedSizeList(..) => self.as_list().cast(target),
130129
DType::Extension(..) => self.as_extension().cast(target),
131130
}
132131
}
@@ -163,12 +162,11 @@ impl Scalar {
163162
.fields()
164163
.map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
165164
.unwrap_or_default(),
166-
DType::List(..) => self
165+
DType::List(..) | DType::FixedSizeList(..) => self
167166
.as_list()
168167
.elements()
169168
.map(|fields| fields.into_iter().map(|f| f.nbytes()).sum::<usize>())
170169
.unwrap_or_default(),
171-
DType::FixedSizeList(..) => unimplemented!("TODO(connor)[FixedSizeList]"),
172170
DType::Extension(_ext_dtype) => self.as_extension().storage().nbytes(),
173171
}
174172
}
@@ -199,7 +197,12 @@ impl Scalar {
199197
Self::struct_(DType::Struct(sf, nullability), fields)
200198
}
201199
DType::List(edt, nullability) => Self::list(edt, vec![], nullability),
202-
DType::FixedSizeList(..) => unimplemented!("TODO(connor)[FixedSizeList]"),
200+
DType::FixedSizeList(edt, size, nullability) => {
201+
let elements = (0..size)
202+
.map(|_| Scalar::default_value(edt.as_ref().clone()))
203+
.collect();
204+
Self::list(edt, elements, nullability)
205+
}
203206
DType::Extension(dt) => {
204207
let scalar = Self::default_value(dt.storage_dtype().clone());
205208
Self::extension(dt, scalar)
@@ -296,6 +299,9 @@ impl Scalar {
296299

297300
/// Returns a view of the scalar as a list scalar.
298301
///
302+
/// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and
303+
/// [`DType::FixedSizeList`].
304+
///
299305
/// # Panics
300306
///
301307
/// Panics if the scalar is not a list type.
@@ -304,8 +310,11 @@ impl Scalar {
304310
}
305311

306312
/// Returns a view of the scalar as a list scalar if it has a list type.
313+
///
314+
/// Note that we use [`ListScalar`] to represent **both** [`DType::List`] and
315+
/// [`DType::FixedSizeList`].
307316
pub fn as_list_opt(&self) -> Option<ListScalar<'_>> {
308-
matches!(self.dtype, DType::List(..)).then(|| self.as_list())
317+
matches!(self.dtype, DType::List(..) | DType::FixedSizeList(..)).then(|| self.as_list())
309318
}
310319

311320
/// Returns a view of the scalar as an extension scalar.
@@ -397,8 +406,7 @@ impl PartialEq for Scalar {
397406
DType::Utf8(_) => self.as_utf8() == other.as_utf8(),
398407
DType::Binary(_) => self.as_binary() == other.as_binary(),
399408
DType::Struct(..) => self.as_struct() == other.as_struct(),
400-
DType::List(..) => self.as_list() == other.as_list(),
401-
DType::FixedSizeList(..) => unimplemented!("TODO(connor)[FixedSizeList]"),
409+
DType::List(..) | DType::FixedSizeList(..) => self.as_list() == other.as_list(),
402410
DType::Extension(_) => self.as_extension() == other.as_extension(),
403411
}
404412
}
@@ -447,8 +455,9 @@ impl PartialOrd for Scalar {
447455
DType::Utf8(_) => self.as_utf8().partial_cmp(&other.as_utf8()),
448456
DType::Binary(_) => self.as_binary().partial_cmp(&other.as_binary()),
449457
DType::Struct(..) => self.as_struct().partial_cmp(&other.as_struct()),
450-
DType::List(..) => self.as_list().partial_cmp(&other.as_list()),
451-
DType::FixedSizeList(..) => unimplemented!("TODO(connor)[FixedSizeList]"),
458+
DType::List(..) | DType::FixedSizeList(..) => {
459+
self.as_list().partial_cmp(&other.as_list())
460+
}
452461
DType::Extension(_) => self.as_extension().partial_cmp(&other.as_extension()),
453462
}
454463
}
@@ -464,8 +473,7 @@ impl Hash for Scalar {
464473
DType::Utf8(_) => self.as_utf8().hash(state),
465474
DType::Binary(_) => self.as_binary().hash(state),
466475
DType::Struct(..) => self.as_struct().hash(state),
467-
DType::List(..) => self.as_list().hash(state),
468-
DType::FixedSizeList(..) => unimplemented!("TODO(connor)[FixedSizeList]"),
476+
DType::List(..) | DType::FixedSizeList(..) => self.as_list().hash(state),
469477
DType::Extension(_) => self.as_extension().hash(state),
470478
}
471479
}

0 commit comments

Comments
 (0)