Skip to content

Commit 797160d

Browse files
authored
Adds a GetItem operator array (#5110)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent bef6443 commit 797160d

File tree

6 files changed

+326
-4
lines changed

6 files changed

+326
-4
lines changed
Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::hash::{Hash, Hasher};
5+
6+
use vortex_compute::mask::MaskValidity;
7+
use vortex_dtype::{DType, FieldName};
8+
use vortex_error::{VortexResult, vortex_bail, vortex_err};
9+
use vortex_vector::VectorOps;
10+
11+
use crate::execution::{BatchKernelRef, BindCtx, kernel};
12+
use crate::stats::{ArrayStats, StatsSetRef};
13+
use crate::vtable::{ArrayVTable, NotSupported, OperatorVTable, VTable, VisitorVTable};
14+
use crate::{
15+
Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef, EncodingId,
16+
EncodingRef, Precision, vtable,
17+
};
18+
19+
vtable!(GetItem);
20+
21+
/// An array that extracts the given field from a Struct array.
22+
///
23+
/// The validity of the field is intersected with the validity of the parent Struct array.
24+
#[derive(Debug, Clone)]
25+
pub struct GetItemArray {
26+
child: ArrayRef,
27+
field: FieldName,
28+
dtype: DType,
29+
stats: ArrayStats,
30+
}
31+
32+
impl GetItemArray {
33+
/// Create a new get_item array.
34+
pub fn try_new(child: ArrayRef, field: FieldName) -> VortexResult<Self> {
35+
let DType::Struct(fields, _) = child.dtype() else {
36+
vortex_bail!(
37+
"GetItem can only be applied to Struct arrays, got {}",
38+
child.dtype()
39+
);
40+
};
41+
42+
let Some(dtype) = fields.field(&field) else {
43+
vortex_bail!("Field '{}' does not exist in Struct array", field);
44+
};
45+
46+
// Make the field nullable if the parent struct is nullable
47+
let dtype = dtype.with_nullability(dtype.nullability() | child.dtype().nullability());
48+
49+
Ok(Self {
50+
child,
51+
field,
52+
dtype,
53+
stats: ArrayStats::default(),
54+
})
55+
}
56+
}
57+
58+
#[derive(Debug, Clone)]
59+
pub struct GetItemEncoding;
60+
61+
impl VTable for GetItemVTable {
62+
type Array = GetItemArray;
63+
type Encoding = GetItemEncoding;
64+
type ArrayVTable = Self;
65+
type CanonicalVTable = NotSupported;
66+
type OperationsVTable = NotSupported;
67+
type ValidityVTable = NotSupported;
68+
type VisitorVTable = Self;
69+
type ComputeVTable = NotSupported;
70+
type EncodeVTable = NotSupported;
71+
type SerdeVTable = NotSupported;
72+
type OperatorVTable = Self;
73+
74+
fn id(_encoding: &Self::Encoding) -> EncodingId {
75+
EncodingId::from("vortex.get_item")
76+
}
77+
78+
fn encoding(_array: &Self::Array) -> EncodingRef {
79+
EncodingRef::from(GetItemEncoding.as_ref())
80+
}
81+
}
82+
83+
impl ArrayVTable<GetItemVTable> for GetItemVTable {
84+
fn len(array: &GetItemArray) -> usize {
85+
array.child.len()
86+
}
87+
88+
fn dtype(array: &GetItemArray) -> &DType {
89+
&array.dtype
90+
}
91+
92+
fn stats(array: &GetItemArray) -> StatsSetRef<'_> {
93+
array.stats.to_ref(array.as_ref())
94+
}
95+
96+
fn array_hash<H: Hasher>(array: &GetItemArray, state: &mut H, precision: Precision) {
97+
array.child.array_hash(state, precision);
98+
array.field.hash(state);
99+
}
100+
101+
fn array_eq(array: &GetItemArray, other: &GetItemArray, precision: Precision) -> bool {
102+
array.child.array_eq(&other.child, precision) && array.field == other.field
103+
}
104+
}
105+
106+
impl VisitorVTable<GetItemVTable> for GetItemVTable {
107+
fn visit_buffers(_array: &GetItemArray, _visitor: &mut dyn ArrayBufferVisitor) {
108+
// No buffers
109+
}
110+
111+
fn visit_children(array: &GetItemArray, visitor: &mut dyn ArrayChildVisitor) {
112+
visitor.visit_child("struct", array.child.as_ref());
113+
}
114+
}
115+
116+
impl OperatorVTable<GetItemVTable> for GetItemVTable {
117+
fn bind(
118+
array: &GetItemArray,
119+
selection: Option<&ArrayRef>,
120+
ctx: &mut dyn BindCtx,
121+
) -> VortexResult<BatchKernelRef> {
122+
let child = ctx.bind(&array.child, selection)?;
123+
124+
// Find the index of the field in the struct
125+
let idx = array
126+
.child
127+
.dtype()
128+
.as_struct_fields()
129+
.find(&array.field)
130+
.ok_or_else(|| vortex_err!("Field '{}' does not exist in Struct array", array.field))?;
131+
132+
Ok(kernel(move || {
133+
let struct_ = child.execute()?.into_struct();
134+
135+
// We must intersect the validity with that of the parent struct
136+
let field = struct_.fields()[idx].clone();
137+
let field = MaskValidity::mask_validity(field, struct_.validity());
138+
139+
Ok(field)
140+
}))
141+
}
142+
}
143+
144+
#[cfg(test)]
145+
mod tests {
146+
use vortex_buffer::{bitbuffer, buffer};
147+
use vortex_dtype::{FieldNames, Nullability, PTypeDowncast};
148+
use vortex_vector::VectorOps;
149+
150+
use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
151+
use crate::compute::arrays::get_item::GetItemArray;
152+
use crate::validity::Validity;
153+
use crate::{ArrayOperator, IntoArray};
154+
155+
#[test]
156+
fn test_get_item_basic() {
157+
// Create a non-nullable struct with non-nullable fields
158+
let int_field = PrimitiveArray::from_iter([10i32, 20, 30, 40]);
159+
let bool_field = BoolArray::from_iter([true, false, true, false]);
160+
161+
let struct_array = StructArray::try_new(
162+
FieldNames::from(["numbers", "flags"]),
163+
vec![int_field.into_array(), bool_field.into_array()],
164+
4,
165+
Validity::NonNullable,
166+
)
167+
.unwrap()
168+
.into_array();
169+
170+
// Extract the "numbers" field
171+
let get_item = GetItemArray::try_new(struct_array, "numbers".into())
172+
.unwrap()
173+
.into_array();
174+
175+
// Verify the dtype is non-nullable
176+
assert_eq!(get_item.dtype().nullability(), Nullability::NonNullable);
177+
178+
// Execute and verify the values
179+
let result = get_item.execute().unwrap().into_primitive().into_i32();
180+
assert_eq!(result.elements(), &buffer![10i32, 20, 30, 40]);
181+
}
182+
183+
#[test]
184+
fn test_get_item_nullable_struct_nonnullable_field() {
185+
// Create a nullable struct with non-nullable field
186+
// The result should be nullable because the struct is nullable
187+
let int_field = PrimitiveArray::from_iter([10i32, 20, 30, 40]);
188+
189+
let struct_array = StructArray::try_new(
190+
FieldNames::from(["numbers"]),
191+
vec![int_field.into_array()],
192+
4,
193+
Validity::from_iter([true, false, true, false]),
194+
)
195+
.unwrap()
196+
.into_array();
197+
198+
// Extract the "numbers" field
199+
let get_item = GetItemArray::try_new(struct_array, "numbers".into())
200+
.unwrap()
201+
.into_array();
202+
203+
// The dtype should be nullable even though the field itself is non-nullable
204+
assert_eq!(get_item.dtype().nullability(), Nullability::Nullable);
205+
206+
// Execute and verify values and validity
207+
let result = get_item.execute().unwrap().into_primitive().into_i32();
208+
assert_eq!(result.elements(), &buffer![10i32, 20, 30, 40]);
209+
210+
// Check that validity was properly intersected
211+
// Elements at indices 1 and 3 should be null due to struct validity
212+
assert_eq!(result.validity().to_bit_buffer(), bitbuffer![1 0 1 0]);
213+
}
214+
215+
#[test]
216+
fn test_get_item_with_selection() {
217+
// Create a struct with multiple fields
218+
let int_field = PrimitiveArray::from_iter([10i32, 20, 30, 40, 50, 60]);
219+
let bool_field = BoolArray::from_iter([true, false, true, false, true, false]);
220+
221+
let struct_array = StructArray::try_new(
222+
FieldNames::from(["numbers", "flags"]),
223+
vec![int_field.into_array(), bool_field.into_array()],
224+
6,
225+
Validity::from_iter([true, true, false, true, true, false]),
226+
)
227+
.unwrap()
228+
.into_array();
229+
230+
// Extract the "numbers" field
231+
let get_item = GetItemArray::try_new(struct_array, "numbers".into())
232+
.unwrap()
233+
.into_array();
234+
235+
// Apply selection mask [1 0 1 0 1 0] => select indices 0, 2, 4
236+
let selection = bitbuffer![1 0 1 0 1 0].into_array();
237+
let result = get_item
238+
.execute_with_selection(Some(&selection))
239+
.unwrap()
240+
.into_primitive()
241+
.into_i32();
242+
243+
// Should have 3 elements: indices 0, 2, 4
244+
assert_eq!(result.len(), 3);
245+
assert_eq!(result.elements(), &buffer![10i32, 30, 50]);
246+
247+
// Check validity: index 0 is valid, index 2 is null (struct), index 4 is valid
248+
assert_eq!(result.validity().to_bit_buffer(), bitbuffer![1 0 1]);
249+
}
250+
251+
#[test]
252+
fn test_get_item_intersects_validity() {
253+
// Test that field validity is intersected with struct validity
254+
// Field has nulls at indices 1, 3
255+
let int_field =
256+
PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, Some(50)]);
257+
258+
// Struct has nulls at indices 2, 4
259+
let struct_array = StructArray::try_new(
260+
FieldNames::from(["values"]),
261+
vec![int_field.into_array()],
262+
5,
263+
Validity::from_iter([true, true, false, true, false]),
264+
)
265+
.unwrap()
266+
.into_array();
267+
268+
let get_item = GetItemArray::try_new(struct_array, "values".into())
269+
.unwrap()
270+
.into_array();
271+
272+
let result = get_item.execute().unwrap().into_primitive().into_i32();
273+
274+
// Verify that nulls are correctly combined:
275+
// Index 0: valid (both valid)
276+
// Index 1: null (field null)
277+
// Index 2: null (struct null)
278+
// Index 3: null (field null)
279+
// Index 4: null (struct null)
280+
assert_eq!(result.validity().to_bit_buffer(), bitbuffer![1 0 0 0 0]);
281+
}
282+
283+
#[test]
284+
fn test_get_item_bool_field() {
285+
// Test extracting a boolean field
286+
let bool_field = BoolArray::from_iter([true, false, true, false]);
287+
288+
let struct_array = StructArray::try_new(
289+
FieldNames::from(["flags"]),
290+
vec![bool_field.into_array()],
291+
4,
292+
Validity::from_iter([true, false, true, true]),
293+
)
294+
.unwrap()
295+
.into_array();
296+
297+
let get_item = GetItemArray::try_new(struct_array, "flags".into())
298+
.unwrap()
299+
.into_array();
300+
301+
let result = get_item.execute().unwrap().into_bool();
302+
303+
// Verify values
304+
assert_eq!(result.bits(), &bitbuffer![1 0 1 0]);
305+
306+
// Verify validity (index 1 should be null from struct)
307+
assert_eq!(result.validity().to_bit_buffer(), bitbuffer![1 0 1 1]);
308+
}
309+
}

vortex-array/src/compute/arrays/is_not_null.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ impl VTable for IsNotNullVTable {
5252
type OperatorVTable = Self;
5353

5454
fn id(_encoding: &Self::Encoding) -> EncodingId {
55-
EncodingId::from("vortex.is_null")
55+
EncodingId::from("vortex.is_not_null")
5656
}
5757

5858
fn encoding(_array: &Self::Array) -> EncodingRef {
@@ -62,7 +62,7 @@ impl VTable for IsNotNullVTable {
6262

6363
impl ArrayVTable<IsNotNullVTable> for IsNotNullVTable {
6464
fn len(array: &IsNotNullArray) -> usize {
65-
array.len()
65+
array.child.len()
6666
}
6767

6868
fn dtype(_array: &IsNotNullArray) -> &DType {

vortex-array/src/compute/arrays/is_null.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl VTable for IsNullVTable {
6363

6464
impl ArrayVTable<IsNullVTable> for IsNullVTable {
6565
fn len(array: &IsNullArray) -> usize {
66-
array.len()
66+
array.child.len()
6767
}
6868

6969
fn dtype(_array: &IsNullArray) -> &DType {

vortex-array/src/compute/arrays/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
pub mod arithmetic;
5+
mod get_item;
56
pub mod is_not_null;
67
pub mod is_null;
78
pub mod logical;

vortex-dtype/src/dtype.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,18 @@ impl DType {
369369
}
370370
}
371371

372+
/// Returns the [`StructFields`] from a struct [`DType`].
373+
///
374+
/// # Panics
375+
///
376+
/// If the [`DType`] is not a struct.
377+
pub fn as_struct_fields(&self) -> &StructFields {
378+
if let Struct(f, _) = self {
379+
return f;
380+
}
381+
vortex_panic!("DType is not a Struct")
382+
}
383+
372384
/// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
373385
pub fn as_struct_fields_opt(&self) -> Option<&StructFields> {
374386
if let Struct(f, _) = self {

vortex-vector/src/struct_/vector.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl VectorOps for StructVector {
154154
}
155155
};
156156

157-
// Convert all of the remaining fields to mutable, if possible.
157+
// Convert all the remaining fields to mutable, if possible.
158158
let mut mutable_fields = Vec::with_capacity(fields.len());
159159
let mut fields_iter = fields.into_iter();
160160

0 commit comments

Comments
 (0)