Skip to content

Commit 536bca9

Browse files
authored
Add GetItemFn (#5609)
Signed-off-by: Nicholas Gates <[email protected]>
1 parent 51401ca commit 536bca9

File tree

4 files changed

+224
-14
lines changed

4 files changed

+224
-14
lines changed
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright the Vortex contributors
3+
4+
use prost::Message;
5+
use vortex_dtype::DType;
6+
use vortex_dtype::FieldName;
7+
use vortex_dtype::FieldPath;
8+
use vortex_dtype::Nullability;
9+
use vortex_error::VortexResult;
10+
use vortex_error::vortex_err;
11+
use vortex_proto::expr as pb;
12+
use vortex_vector::Datum;
13+
use vortex_vector::ScalarOps;
14+
use vortex_vector::VectorOps;
15+
16+
use crate::expr::Expression;
17+
use crate::expr::StatsCatalog;
18+
use crate::expr::functions::ArgName;
19+
use crate::expr::functions::Arity;
20+
use crate::expr::functions::ExecutionArgs;
21+
use crate::expr::functions::FunctionId;
22+
use crate::expr::functions::VTable;
23+
use crate::expr::stats::Stat;
24+
25+
/// Scalar function (`vortex.get_item`) that extracts a single named field from a
/// struct-typed input. Its only option is the [`FieldName`] of the field to extract.
pub struct GetItemFn;
26+
impl VTable for GetItemFn {
27+
type Options = FieldName;
28+
29+
fn id(&self) -> FunctionId {
30+
FunctionId::from("vortex.get_item")
31+
}
32+
33+
fn serialize(&self, field_name: &FieldName) -> VortexResult<Option<Vec<u8>>> {
34+
Ok(Some(
35+
pb::GetItemOpts {
36+
path: field_name.to_string(),
37+
}
38+
.encode_to_vec(),
39+
))
40+
}
41+
42+
fn deserialize(&self, bytes: &[u8]) -> VortexResult<Self::Options> {
43+
let opts = pb::GetItemOpts::decode(bytes)?;
44+
Ok(FieldName::from(opts.path))
45+
}
46+
47+
fn arity(&self, _field_name: &FieldName) -> Arity {
48+
Arity::Exact(1)
49+
}
50+
51+
fn arg_name(&self, _field_name: &FieldName, _arg_idx: usize) -> ArgName {
52+
ArgName::from("input")
53+
}
54+
55+
fn stat_expression(
56+
&self,
57+
field_name: &FieldName,
58+
_expr: &Expression,
59+
stat: Stat,
60+
catalog: &dyn StatsCatalog,
61+
) -> Option<Expression> {
62+
// TODO(ngates): I think we can do better here and support stats over nested fields.
63+
// It would be nice if delegating to our child would return a struct of statistics
64+
// matching the nested DType such that we can write:
65+
// `get_item(expr.child(0).stat_expression(...), expr.data().field_name())`
66+
67+
// TODO(ngates): this is a bug whereby we may return stats for a nested field of the same
68+
// name as a field in the root struct. This should be resolved with upcoming change to
69+
// falsify expressions, but for now I'm preserving the existing buggy behavior.
70+
catalog.stats_ref(&FieldPath::from_name(field_name.clone()), stat)
71+
}
72+
73+
fn return_dtype(&self, field_name: &FieldName, arg_types: &[DType]) -> VortexResult<DType> {
74+
let struct_dtype = &arg_types[0];
75+
let field_dtype = struct_dtype
76+
.as_struct_fields_opt()
77+
.and_then(|st| st.field(field_name))
78+
.ok_or_else(|| {
79+
vortex_err!("Couldn't find the {} field in the input scope", field_name)
80+
})?;
81+
82+
// Match here to avoid cloning the dtype if nullability doesn't need to change
83+
if matches!(
84+
(struct_dtype.nullability(), field_dtype.nullability()),
85+
(Nullability::Nullable, Nullability::NonNullable)
86+
) {
87+
return Ok(field_dtype.with_nullability(Nullability::Nullable));
88+
}
89+
90+
Ok(field_dtype)
91+
}
92+
93+
fn execute(&self, field_name: &FieldName, args: &ExecutionArgs) -> VortexResult<Datum> {
94+
let struct_dtype = args
95+
.input_type(0)
96+
.as_struct_fields_opt()
97+
.ok_or_else(|| vortex_err!("Expected struct dtype for child of GetItem expression"))?;
98+
let field_idx = struct_dtype
99+
.find(field_name)
100+
.ok_or_else(|| vortex_err!("Field {} not found in struct dtype", field_name))?;
101+
102+
match args.input_datums(0) {
103+
Datum::Scalar(s) => {
104+
let mut field = s.as_struct().field(field_idx);
105+
field.mask_validity(s.is_valid());
106+
Ok(Datum::Scalar(field))
107+
}
108+
Datum::Vector(v) => {
109+
let mut field = v.as_struct().fields()[field_idx].clone();
110+
field.mask_validity(v.validity());
111+
Ok(Datum::Vector(field))
112+
}
113+
}
114+
}
115+
}

vortex-array/src/scalar_fns/mod.rs

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
//! the equivalent Arrow compute function.
1111
1212
use vortex_dtype::DType;
13+
use vortex_dtype::FieldName;
1314
use vortex_error::VortexResult;
1415

1516
use crate::Array;
@@ -20,6 +21,7 @@ use crate::expr::ScalarFnExprExt;
2021
use crate::expr::functions::EmptyOptions;
2122

2223
pub mod cast;
24+
pub mod get_item;
2325
pub mod is_null;
2426
pub mod mask;
2527
pub mod not;
@@ -29,66 +31,80 @@ pub trait ExprBuiltins: Sized {
2931
/// Cast to the given data type.
3032
fn cast(&self, dtype: DType) -> VortexResult<Expression>;
3133

34+
/// Get item by field name (for struct types).
35+
fn get_item(&self, field_name: impl Into<FieldName>) -> VortexResult<Expression>;
36+
3237
/// Is null check.
3338
fn is_null(&self) -> VortexResult<Expression>;
3439

35-
/// Boolean negation.
36-
fn not(&self) -> VortexResult<Expression>;
37-
3840
/// Mask the expression using the given boolean mask.
3941
/// The resulting expression's validity is the intersection of the original expression's
4042
/// validity.
4143
fn mask(&self, mask: Expression) -> VortexResult<Expression>;
44+
45+
/// Boolean negation.
46+
fn not(&self) -> VortexResult<Expression>;
4247
}
4348

4449
impl ExprBuiltins for Expression {
4550
fn cast(&self, dtype: DType) -> VortexResult<Expression> {
4651
cast::CastFn.try_new_expr(dtype, [self.clone()])
4752
}
4853

49-
fn is_null(&self) -> VortexResult<Expression> {
50-
is_null::IsNullFn.try_new_expr(EmptyOptions, [self.clone()])
54+
fn get_item(&self, field_name: impl Into<FieldName>) -> VortexResult<Expression> {
55+
get_item::GetItemFn.try_new_expr(field_name.into(), [self.clone()])
5156
}
5257

53-
fn not(&self) -> VortexResult<Expression> {
54-
not::NotFn.try_new_expr(EmptyOptions, [self.clone()])
58+
fn is_null(&self) -> VortexResult<Expression> {
59+
is_null::IsNullFn.try_new_expr(EmptyOptions, [self.clone()])
5560
}
5661

5762
fn mask(&self, mask: Expression) -> VortexResult<Expression> {
5863
mask::MaskFn.try_new_expr(EmptyOptions, [self.clone(), mask])
5964
}
65+
66+
fn not(&self) -> VortexResult<Expression> {
67+
not::NotFn.try_new_expr(EmptyOptions, [self.clone()])
68+
}
6069
}
6170

6271
pub trait ArrayBuiltins: Sized {
6372
/// Cast to the given data type.
6473
fn cast(&self, dtype: DType) -> VortexResult<ArrayRef>;
6574

75+
/// Get item by field name (for struct types).
76+
fn get_item(&self, field_name: impl Into<FieldName>) -> VortexResult<ArrayRef>;
77+
6678
/// Is null check.
6779
fn is_null(&self) -> VortexResult<ArrayRef>;
6880

69-
/// Boolean negation.
70-
fn not(&self) -> VortexResult<ArrayRef>;
71-
7281
/// Mask the array using the given boolean mask.
7382
/// The resulting array's validity is the intersection of the original array's validity
7483
/// and the mask's validity.
7584
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef>;
85+
86+
/// Boolean negation.
87+
fn not(&self) -> VortexResult<ArrayRef>;
7688
}
7789

7890
impl ArrayBuiltins for ArrayRef {
7991
fn cast(&self, dtype: DType) -> VortexResult<ArrayRef> {
8092
cast::CastFn.try_new_array(self.len(), dtype, [self.clone()])
8193
}
8294

83-
fn is_null(&self) -> VortexResult<ArrayRef> {
84-
is_null::IsNullFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
95+
fn get_item(&self, field_name: impl Into<FieldName>) -> VortexResult<ArrayRef> {
96+
get_item::GetItemFn.try_new_array(self.len(), field_name.into(), [self.clone()])
8597
}
8698

87-
fn not(&self) -> VortexResult<ArrayRef> {
88-
not::NotFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
99+
fn is_null(&self) -> VortexResult<ArrayRef> {
100+
is_null::IsNullFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
89101
}
90102

91103
fn mask(&self, mask: &ArrayRef) -> VortexResult<ArrayRef> {
92104
mask::MaskFn.try_new_array(self.len(), EmptyOptions, [self.clone(), mask.clone()])
93105
}
106+
107+
fn not(&self) -> VortexResult<ArrayRef> {
108+
not::NotFn.try_new_array(self.len(), EmptyOptions, [self.clone()])
109+
}
94110
}

vortex-vector/src/scalar.rs

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,77 @@ impl Scalar {
130130
vortex_panic!("Cannot convert non-struct Scalar into StructScalar");
131131
}
132132
}
133+
134+
impl Scalar {
135+
/// Converts the `Scalar` into a `NullScalar`.
136+
pub fn as_null(&self) -> &NullScalar {
137+
if let Scalar::Null(scalar) = self {
138+
return scalar;
139+
}
140+
vortex_panic!("Cannot convert non-null Scalar into NullScalar");
141+
}
142+
143+
/// Converts the `Scalar` into a `BoolScalar`.
144+
pub fn as_bool(&self) -> &BoolScalar {
145+
if let Scalar::Bool(scalar) = self {
146+
return scalar;
147+
}
148+
vortex_panic!("Cannot convert non-bool Scalar into BoolScalar");
149+
}
150+
151+
/// Converts the `Scalar` into a `DecimalScalar`.
152+
pub fn as_decimal(&self) -> &DecimalScalar {
153+
if let Scalar::Decimal(scalar) = self {
154+
return scalar;
155+
}
156+
vortex_panic!("Cannot convert non-decimal Scalar into DecimalScalar");
157+
}
158+
159+
/// Converts the `Scalar` into a `PrimitiveScalar`.
160+
pub fn as_primitive(&self) -> &PrimitiveScalar {
161+
if let Scalar::Primitive(scalar) = self {
162+
return scalar;
163+
}
164+
vortex_panic!("Cannot convert non-primitive Scalar into PrimitiveScalar");
165+
}
166+
167+
/// Converts the `Scalar` into a `StringScalar`.
168+
pub fn as_string(&self) -> &StringScalar {
169+
if let Scalar::String(scalar) = self {
170+
return scalar;
171+
}
172+
vortex_panic!("Cannot convert non-string Scalar into StringScalar");
173+
}
174+
175+
/// Converts the `Scalar` into a `BinaryScalar`.
176+
pub fn as_binary(&self) -> &BinaryScalar {
177+
if let Scalar::Binary(scalar) = self {
178+
return scalar;
179+
}
180+
vortex_panic!("Cannot convert non-binary Scalar into BinaryScalar");
181+
}
182+
183+
/// Converts the `Scalar` into a `ListViewScalar`.
184+
pub fn as_list(&self) -> &ListViewScalar {
185+
if let Scalar::List(scalar) = self {
186+
return scalar;
187+
}
188+
vortex_panic!("Cannot convert non-list Scalar into ListViewScalar");
189+
}
190+
191+
/// Converts the `Scalar` into a `FixedSizeListScalar`.
192+
pub fn as_fixed_size_list(&self) -> &FixedSizeListScalar {
193+
if let Scalar::FixedSizeList(scalar) = self {
194+
return scalar;
195+
}
196+
vortex_panic!("Cannot convert non-fixed-size-list Scalar into FixedSizeListScalar");
197+
}
198+
199+
/// Converts the `Scalar` into a `StructScalar`.
200+
pub fn as_struct(&self) -> &StructScalar {
201+
if let Scalar::Struct(scalar) = self {
202+
return scalar;
203+
}
204+
vortex_panic!("Cannot convert non-struct Scalar into StructScalar");
205+
}
206+
}

vortex-vector/src/struct_/scalar.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ impl StructScalar {
2929
pub fn value(&self) -> &StructVector {
3030
&self.0
3131
}
32+
33+
/// Returns the nth field scalar of the struct.
34+
pub fn field(&self, field_idx: usize) -> Scalar {
35+
self.0.fields()[field_idx].scalar_at(0)
36+
}
3237
}
3338

3439
impl ScalarOps for StructScalar {

0 commit comments

Comments
 (0)