Skip to content

Commit 91624db

Browse files
authored
Add cast expression (#3440)
1 parent 08d730d commit 91624db

File tree

4 files changed

+176
-1
lines changed

4 files changed

+176
-1
lines changed

vortex-expr/src/cast.rs

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
use std::any::Any;
2+
use std::fmt::Display;
3+
use std::sync::Arc;
4+
5+
use vortex_array::compute::cast as compute_cast;
6+
use vortex_array::{Array, ArrayRef};
7+
use vortex_dtype::DType;
8+
use vortex_error::{VortexExpect, VortexResult};
9+
10+
use crate::{ExprRef, VortexExpr};
11+
12+
#[derive(Debug, Eq, Hash)]
13+
#[allow(clippy::derived_hash_with_manual_eq)]
14+
pub struct Cast {
15+
target: DType,
16+
child: ExprRef,
17+
}
18+
19+
impl Cast {
20+
pub fn new_expr(child: ExprRef, target: DType) -> ExprRef {
21+
Arc::new(Self { target, child })
22+
}
23+
}
24+
25+
impl PartialEq for Cast {
26+
fn eq(&self, other: &Self) -> bool {
27+
self.target.eq(&other.target) && self.child.eq(&other.child)
28+
}
29+
}
30+
31+
impl Display for Cast {
32+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33+
write!(f, "cast({}, {})", self.child, self.target)
34+
}
35+
}
36+
37+
#[cfg(feature = "proto")]
38+
pub(crate) mod proto {
39+
use vortex_dtype::DType;
40+
use vortex_error::{VortexResult, vortex_bail, vortex_err};
41+
use vortex_proto::expr::kind;
42+
use vortex_proto::expr::kind::Kind;
43+
44+
use crate::cast::Cast;
45+
use crate::{ExprDeserialize, ExprRef, ExprSerializable, Id};
46+
47+
pub(crate) struct CastSerde;
48+
49+
impl Id for CastSerde {
50+
fn id(&self) -> &'static str {
51+
"cast"
52+
}
53+
}
54+
55+
impl ExprDeserialize for CastSerde {
56+
fn deserialize(&self, kind: &Kind, children: Vec<ExprRef>) -> VortexResult<ExprRef> {
57+
let Kind::Cast(kind::Cast { target }) = kind else {
58+
vortex_bail!("wrong kind {:?}, want cast", kind)
59+
};
60+
let target: DType = target
61+
.as_ref()
62+
.ok_or_else(|| vortex_err!("empty target dtype"))?
63+
.try_into()?;
64+
65+
Ok(Cast::new_expr(children[0].clone(), target))
66+
}
67+
}
68+
69+
impl ExprSerializable for Cast {
70+
fn id(&self) -> &'static str {
71+
CastSerde.id()
72+
}
73+
74+
fn serialize_kind(&self) -> VortexResult<Kind> {
75+
Ok(Kind::Cast(kind::Cast {
76+
target: Some((&self.target).into()),
77+
}))
78+
}
79+
}
80+
}
81+
82+
impl VortexExpr for Cast {
83+
fn as_any(&self) -> &dyn Any {
84+
self
85+
}
86+
87+
fn unchecked_evaluate(&self, batch: &dyn Array) -> VortexResult<ArrayRef> {
88+
let array = self.child.evaluate(batch)?;
89+
compute_cast(&array, &self.target)
90+
}
91+
92+
fn children(&self) -> Vec<&ExprRef> {
93+
vec![&self.child]
94+
}
95+
96+
fn replacing_children(self: Arc<Self>, mut children: Vec<ExprRef>) -> ExprRef {
97+
Self::new_expr(
98+
children
99+
.pop()
100+
.vortex_expect("Cast::replacing_children should have one child"),
101+
self.target.clone(),
102+
)
103+
}
104+
105+
fn return_dtype(&self, _scope_dtype: &DType) -> VortexResult<DType> {
106+
Ok(self.target.clone())
107+
}
108+
}
109+
110+
pub fn cast(child: ExprRef, target: DType) -> ExprRef {
111+
Cast::new_expr(child, target)
112+
}
113+
114+
#[cfg(test)]
115+
mod tests {
116+
use vortex_array::IntoArray;
117+
use vortex_array::arrays::StructArray;
118+
use vortex_buffer::buffer;
119+
use vortex_dtype::{DType, Nullability, PType};
120+
121+
use crate::{ExprRef, cast, get_item, ident, test_harness};
122+
123+
#[test]
124+
fn dtype() {
125+
let dtype = test_harness::struct_dtype();
126+
assert_eq!(
127+
cast(ident(), DType::Bool(Nullability::NonNullable))
128+
.return_dtype(&dtype)
129+
.unwrap(),
130+
DType::Bool(Nullability::NonNullable)
131+
);
132+
}
133+
134+
#[test]
135+
fn replace_children() {
136+
let expr = cast(ident(), DType::Bool(Nullability::Nullable));
137+
let _ = expr.replacing_children(vec![ident()]);
138+
}
139+
140+
#[test]
141+
fn evaluate() {
142+
let test_array = StructArray::from_fields(&[
143+
("a", buffer![0i32, 1, 2].into_array()),
144+
("b", buffer![4i64, 5, 6].into_array()),
145+
])
146+
.unwrap()
147+
.into_array();
148+
149+
let expr: ExprRef = cast(
150+
get_item("a", ident()),
151+
DType::Primitive(PType::I64, Nullability::NonNullable),
152+
);
153+
let result = expr.unchecked_evaluate(&test_array).unwrap();
154+
155+
assert_eq!(
156+
result.dtype(),
157+
&DType::Primitive(PType::I64, Nullability::NonNullable)
158+
);
159+
}
160+
}

vortex-expr/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use dyn_hash::DynHash;
77
mod binary;
88

99
mod between;
10+
mod cast;
1011
mod field;
1112
pub mod forms;
1213
mod get_item;
@@ -27,6 +28,7 @@ pub mod traversal;
2728

2829
pub use between::*;
2930
pub use binary::*;
31+
pub use cast::*;
3032
pub use get_item::*;
3133
pub use identity::*;
3234
pub use is_null::*;

vortex-proto/proto/expr.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ syntax = "proto3";
33
package vortex.expr;
44

55
import "scalar.proto";
6+
import "dtype.proto";
67

78
option java_package = "dev.vortex.proto";
89
option java_outer_classname = "ExprProtos";
@@ -27,6 +28,7 @@ message Kind {
2728
Between between = 8;
2829
Like like = 9;
2930
IsNull is_null = 10;
31+
Cast cast = 11;
3032
}
3133

3234
message Literal {
@@ -77,6 +79,10 @@ message Kind {
7779
message IsNull {
7880

7981
}
82+
83+
message Cast {
84+
vortex.dtype.DType target = 1;
85+
}
8086
}
8187

8288

vortex-proto/src/generated/vortex.expr.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ pub struct Expr {
1212
#[derive(Clone, PartialEq, ::prost::Message)]
1313
pub struct Kind {
1414
/// This enum is very unstable, and will likely be replaced with something more extensible.
15-
#[prost(oneof = "kind::Kind", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10")]
15+
#[prost(oneof = "kind::Kind", tags = "1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11")]
1616
pub kind: ::core::option::Option<kind::Kind>,
1717
}
1818
/// Nested message and enum types in `Kind`.
@@ -54,6 +54,11 @@ pub mod kind {
5454
}
5555
#[derive(Clone, Copy, PartialEq, ::prost::Message)]
5656
pub struct IsNull {}
57+
#[derive(Clone, PartialEq, ::prost::Message)]
58+
pub struct Cast {
59+
#[prost(message, optional, tag = "1")]
60+
pub target: ::core::option::Option<super::super::dtype::DType>,
61+
}
5762
#[derive(
5863
Clone,
5964
Copy,
@@ -131,5 +136,7 @@ pub mod kind {
131136
Like(Like),
132137
#[prost(message, tag = "10")]
133138
IsNull(IsNull),
139+
#[prost(message, tag = "11")]
140+
Cast(Cast),
134141
}
135142
}

0 commit comments

Comments
 (0)