Skip to content

Commit 7910345

Browse files
committed
GetItemArray
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 846f65f commit 7910345

File tree

6 files changed

+335
-4
lines changed

6 files changed

+335
-4
lines changed
Lines changed: 318 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,318 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use std::hash::{Hash, Hasher};
5+
6+
use vortex_compute::mask::MaskValidity;
7+
use vortex_dtype::{DType, FieldName};
8+
use vortex_error::{VortexResult, vortex_bail, vortex_err};
9+
use vortex_vector::VectorOps;
10+
11+
use crate::execution::{BatchKernelRef, BindCtx, kernel};
12+
use crate::stats::{ArrayStats, StatsSetRef};
13+
use crate::vtable::{ArrayVTable, NotSupported, OperatorVTable, VTable, VisitorVTable};
14+
use crate::{
15+
Array, ArrayBufferVisitor, ArrayChildVisitor, ArrayEq, ArrayHash, ArrayRef, EncodingId,
16+
EncodingRef, Precision, vtable,
17+
};
18+
19+
vtable!(GetItem);
20+
21+
/// An array that extracts the given field from a Struct array.
22+
///
23+
/// The validity of the field is intersected with the validity of the parent Struct array.
24+
#[derive(Debug, Clone)]
25+
pub struct GetItemArray {
26+
child: ArrayRef,
27+
field: FieldName,
28+
dtype: DType,
29+
stats: ArrayStats,
30+
}
31+
32+
impl GetItemArray {
33+
/// Create a new get_item array.
34+
pub fn try_new(child: ArrayRef, field: FieldName) -> VortexResult<Self> {
35+
let DType::Struct(fields, _) = child.dtype() else {
36+
vortex_bail!(
37+
"GetItem can only be applied to Struct arrays, got {}",
38+
child.dtype()
39+
);
40+
};
41+
42+
let Some(dtype) = fields.field(&field) else {
43+
vortex_bail!("Field '{}' does not exist in Struct array", field);
44+
};
45+
46+
// Make the field nullable if the parent struct is nullable
47+
let dtype = dtype.with_nullability(dtype.nullability() | child.dtype().nullability());
48+
49+
Ok(Self {
50+
child,
51+
field,
52+
dtype,
53+
stats: ArrayStats::default(),
54+
})
55+
}
56+
}
57+
58+
#[derive(Debug, Clone)]
59+
pub struct GetItemEncoding;
60+
61+
impl VTable for GetItemVTable {
62+
type Array = GetItemArray;
63+
type Encoding = GetItemEncoding;
64+
type ArrayVTable = Self;
65+
type CanonicalVTable = NotSupported;
66+
type OperationsVTable = NotSupported;
67+
type ValidityVTable = NotSupported;
68+
type VisitorVTable = Self;
69+
type ComputeVTable = NotSupported;
70+
type EncodeVTable = NotSupported;
71+
type SerdeVTable = NotSupported;
72+
type OperatorVTable = Self;
73+
74+
fn id(_encoding: &Self::Encoding) -> EncodingId {
75+
EncodingId::from("vortex.get_item")
76+
}
77+
78+
fn encoding(_array: &Self::Array) -> EncodingRef {
79+
EncodingRef::from(GetItemEncoding.as_ref())
80+
}
81+
}
82+
83+
impl ArrayVTable<GetItemVTable> for GetItemVTable {
84+
fn len(array: &GetItemArray) -> usize {
85+
array.child.len()
86+
}
87+
88+
fn dtype(array: &GetItemArray) -> &DType {
89+
&array.dtype
90+
}
91+
92+
fn stats(array: &GetItemArray) -> StatsSetRef<'_> {
93+
array.stats.to_ref(array.as_ref())
94+
}
95+
96+
fn array_hash<H: Hasher>(array: &GetItemArray, state: &mut H, precision: Precision) {
97+
array.child.array_hash(state, precision);
98+
array.field.hash(state);
99+
}
100+
101+
fn array_eq(array: &GetItemArray, other: &GetItemArray, precision: Precision) -> bool {
102+
array.child.array_eq(&other.child, precision) && array.field == other.field
103+
}
104+
}
105+
106+
impl VisitorVTable<GetItemVTable> for GetItemVTable {
107+
fn visit_buffers(_array: &GetItemArray, _visitor: &mut dyn ArrayBufferVisitor) {
108+
// No buffers
109+
}
110+
111+
fn visit_children(array: &GetItemArray, visitor: &mut dyn ArrayChildVisitor) {
112+
visitor.visit_child("struct", array.child.as_ref());
113+
}
114+
}
115+
116+
impl OperatorVTable<GetItemVTable> for GetItemVTable {
117+
fn bind(
118+
array: &GetItemArray,
119+
selection: Option<&ArrayRef>,
120+
ctx: &mut dyn BindCtx,
121+
) -> VortexResult<BatchKernelRef> {
122+
let child = ctx.bind(&array.child, selection)?;
123+
124+
// Find the index of the field in the struct
125+
let idx = array
126+
.child
127+
.dtype()
128+
.as_struct_fields()
129+
.find(&array.field)
130+
.ok_or_else(|| vortex_err!("Field '{}' does not exist in Struct array", array.field))?;
131+
132+
Ok(kernel(move || {
133+
let struct_ = child.execute()?.into_struct();
134+
135+
// We must intersect the validity with that of the parent struct
136+
let field = struct_.fields()[idx].clone();
137+
let field = MaskValidity::mask_validity(field, &struct_.validity());
138+
139+
Ok(field)
140+
}))
141+
}
142+
}
143+
144+
#[cfg(test)]
145+
mod tests {
146+
use vortex_buffer::{bitbuffer, buffer};
147+
use vortex_dtype::{FieldNames, Nullability, PTypeDowncast};
148+
use vortex_vector::VectorOps;
149+
150+
use crate::arrays::{BoolArray, PrimitiveArray, StructArray};
151+
use crate::compute::arrays::get_item::GetItemArray;
152+
use crate::validity::Validity;
153+
use crate::{ArrayOperator, IntoArray};
154+
155+
#[test]
156+
fn test_get_item_basic() {
157+
// Create a non-nullable struct with non-nullable fields
158+
let int_field = PrimitiveArray::from_iter([10i32, 20, 30, 40]);
159+
let bool_field = BoolArray::from_iter([true, false, true, false]);
160+
161+
let struct_array = StructArray::try_new(
162+
FieldNames::from(["numbers", "flags"]),
163+
vec![int_field.into_array(), bool_field.into_array()],
164+
4,
165+
Validity::NonNullable,
166+
)
167+
.unwrap()
168+
.into_array();
169+
170+
// Extract the "numbers" field
171+
let get_item = GetItemArray::try_new(struct_array, "numbers".into())
172+
.unwrap()
173+
.into_array();
174+
175+
// Verify the dtype is non-nullable
176+
assert_eq!(get_item.dtype().nullability(), Nullability::NonNullable);
177+
178+
// Execute and verify the values
179+
let result = get_item.execute().unwrap().into_primitive().into_i32();
180+
assert_eq!(result.elements(), &buffer![10i32, 20, 30, 40]);
181+
}
182+
183+
#[test]
184+
fn test_get_item_nullable_struct_nonnullable_field() {
185+
// Create a nullable struct with non-nullable field
186+
// The result should be nullable because the struct is nullable
187+
let int_field = PrimitiveArray::from_iter([10i32, 20, 30, 40]);
188+
189+
let struct_array = StructArray::try_new(
190+
FieldNames::from(["numbers"]),
191+
vec![int_field.into_array()],
192+
4,
193+
Validity::from_iter([true, false, true, false]),
194+
)
195+
.unwrap()
196+
.into_array();
197+
198+
// Extract the "numbers" field
199+
let get_item = GetItemArray::try_new(struct_array, "numbers".into())
200+
.unwrap()
201+
.into_array();
202+
203+
// The dtype should be nullable even though the field itself is non-nullable
204+
assert_eq!(get_item.dtype().nullability(), Nullability::Nullable);
205+
206+
// Execute and verify values and validity
207+
let result = get_item.execute().unwrap().into_primitive().into_i32();
208+
assert_eq!(result.elements(), &buffer![10i32, 20, 30, 40]);
209+
210+
// Check that validity was properly intersected
211+
// Elements at indices 1 and 3 should be null due to struct validity
212+
assert_eq!(result.get(0), Some(&10));
213+
assert_eq!(result.get(1), None); // Null from struct
214+
assert_eq!(result.get(2), Some(&30));
215+
assert_eq!(result.get(3), None); // Null from struct
216+
}
217+
218+
#[test]
219+
fn test_get_item_with_selection() {
220+
// Create a struct with multiple fields
221+
let int_field = PrimitiveArray::from_iter([10i32, 20, 30, 40, 50, 60]);
222+
let bool_field = BoolArray::from_iter([true, false, true, false, true, false]);
223+
224+
let struct_array = StructArray::try_new(
225+
FieldNames::from(["numbers", "flags"]),
226+
vec![int_field.into_array(), bool_field.into_array()],
227+
6,
228+
Validity::from_iter([true, true, false, true, true, false]),
229+
)
230+
.unwrap()
231+
.into_array();
232+
233+
// Extract the "numbers" field
234+
let get_item = GetItemArray::try_new(struct_array, "numbers".into())
235+
.unwrap()
236+
.into_array();
237+
238+
// Apply selection mask [1 0 1 0 1 0] => select indices 0, 2, 4
239+
let selection = bitbuffer![1 0 1 0 1 0].into_array();
240+
let result = get_item
241+
.execute_with_selection(Some(&selection))
242+
.unwrap()
243+
.into_primitive()
244+
.into_i32();
245+
246+
// Should have 3 elements: indices 0, 2, 4
247+
assert_eq!(result.len(), 3);
248+
assert_eq!(result.elements(), &buffer![10i32, 30, 50]);
249+
250+
// Check validity: index 0 is valid, index 2 is null (struct), index 4 is valid
251+
assert_eq!(result.get(0), Some(&10));
252+
assert_eq!(result.get(1), None); // Index 2 of original was null in struct
253+
assert_eq!(result.get(2), Some(&50));
254+
}
255+
256+
#[test]
257+
fn test_get_item_intersects_validity() {
258+
// Test that field validity is intersected with struct validity
259+
// Field has nulls at indices 1, 3
260+
let int_field =
261+
PrimitiveArray::from_option_iter([Some(10i32), None, Some(30), None, Some(50)]);
262+
263+
// Struct has nulls at indices 2, 4
264+
let struct_array = StructArray::try_new(
265+
FieldNames::from(["values"]),
266+
vec![int_field.into_array()],
267+
5,
268+
Validity::from_iter([true, true, false, true, false]),
269+
)
270+
.unwrap()
271+
.into_array();
272+
273+
let get_item = GetItemArray::try_new(struct_array, "values".into())
274+
.unwrap()
275+
.into_array();
276+
277+
let result = get_item.execute().unwrap().into_primitive().into_i32();
278+
279+
// Verify that nulls are correctly combined:
280+
// Index 0: valid (both valid)
281+
// Index 1: null (field null)
282+
// Index 2: null (struct null)
283+
// Index 3: null (field null)
284+
// Index 4: null (struct null)
285+
assert_eq!(result.get(0), Some(&10));
286+
assert_eq!(result.get(1), None);
287+
assert_eq!(result.get(2), None);
288+
assert_eq!(result.get(3), None);
289+
assert_eq!(result.get(4), None);
290+
}
291+
292+
#[test]
293+
fn test_get_item_bool_field() {
294+
// Test extracting a boolean field
295+
let bool_field = BoolArray::from_iter([true, false, true, false]);
296+
297+
let struct_array = StructArray::try_new(
298+
FieldNames::from(["flags"]),
299+
vec![bool_field.into_array()],
300+
4,
301+
Validity::from_iter([true, false, true, true]),
302+
)
303+
.unwrap()
304+
.into_array();
305+
306+
let get_item = GetItemArray::try_new(struct_array, "flags".into())
307+
.unwrap()
308+
.into_array();
309+
310+
let result = get_item.execute().unwrap().into_bool();
311+
312+
// Verify values
313+
assert_eq!(result.bits(), &bitbuffer![1 0 1 0]);
314+
315+
// Verify validity (index 1 should be null from struct)
316+
assert_eq!(result.validity().to_bit_buffer(), bitbuffer![1 0 1 1]);
317+
}
318+
}

vortex-array/src/compute/arrays/is_not_null.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ impl VTable for IsNotNullVTable {
5252
type OperatorVTable = Self;
5353

5454
fn id(_encoding: &Self::Encoding) -> EncodingId {
55-
EncodingId::from("vortex.is_null")
55+
EncodingId::from("vortex.is_not_null")
5656
}
5757

5858
fn encoding(_array: &Self::Array) -> EncodingRef {
@@ -62,7 +62,7 @@ impl VTable for IsNotNullVTable {
6262

6363
impl ArrayVTable<IsNotNullVTable> for IsNotNullVTable {
6464
fn len(array: &IsNotNullArray) -> usize {
65-
array.len()
65+
array.child.len()
6666
}
6767

6868
fn dtype(_array: &IsNotNullArray) -> &DType {

vortex-array/src/compute/arrays/is_null.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ impl VTable for IsNullVTable {
6363

6464
impl ArrayVTable<IsNullVTable> for IsNullVTable {
6565
fn len(array: &IsNullArray) -> usize {
66-
array.len()
66+
array.child.len()
6767
}
6868

6969
fn dtype(_array: &IsNullArray) -> &DType {

vortex-array/src/compute/arrays/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-FileCopyrightText: Copyright the Vortex contributors
33

44
pub mod arithmetic;
5+
mod get_item;
56
pub mod is_not_null;
67
pub mod is_null;
78
pub mod logical;

vortex-dtype/src/dtype.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,18 @@ impl DType {
369369
}
370370
}
371371

372+
/// Returns the [`StructFields`] from a struct [`DType`].
373+
///
374+
/// # Panics
375+
///
376+
/// If the [`DType`] is not a struct.
377+
pub fn as_struct_fields(&self) -> &StructFields {
378+
if let Struct(f, _) = self {
379+
return f;
380+
}
381+
vortex_panic!("DType is not a Struct")
382+
}
383+
372384
/// Get the `StructDType` if `self` is a `StructDType`, otherwise `None`
373385
pub fn as_struct_fields_opt(&self) -> Option<&StructFields> {
374386
if let Struct(f, _) = self {

vortex-vector/src/struct_/vector.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ impl VectorOps for StructVector {
154154
}
155155
};
156156

157-
// Convert all of the remaining fields to mutable, if possible.
157+
// Convert all the remaining fields to mutable, if possible.
158158
let mut mutable_fields = Vec::with_capacity(fields.len());
159159
let mut fields_iter = fields.into_iter();
160160

0 commit comments

Comments
 (0)