44use vortex_dtype:: DType ;
55use vortex_error:: { VortexResult , vortex_err} ;
66
7- use crate :: expr:: Expression ;
87use crate :: expr:: exprs:: get_item:: get_item;
98use crate :: expr:: exprs:: pack:: pack;
109use crate :: expr:: exprs:: select:: Select ;
10+ use crate :: expr:: transform:: traits:: { ReduceRule , RewriteContext } ;
1111use crate :: expr:: traversal:: { NodeExt , Transformed } ;
12+ use crate :: expr:: { ExprId , Expression } ;
1213
1314/// Replaces [crate::SelectExpr] with a combination of [crate::GetItem] and [crate::Pack] expressions.
1415pub ( crate ) fn remove_select ( e : Expression , ctx : & DType ) -> VortexResult < Expression > {
@@ -20,38 +21,87 @@ fn remove_select_transformer(
2021 node : Expression ,
2122 ctx : & DType ,
2223) -> VortexResult < Transformed < Expression > > {
23- match node. as_opt :: < Select > ( ) {
24- None => Ok ( Transformed :: no ( node) ) ,
25- Some ( select) => {
26- let child = select. child ( ) ;
27- let child_dtype = child. return_dtype ( ctx) ?;
28- let child_nullability = child_dtype. nullability ( ) ;
29-
30- let child_dtype = child_dtype. as_struct_fields_opt ( ) . ok_or_else ( || {
31- vortex_err ! (
32- "Select child must return a struct dtype, however it was a {}" ,
33- child_dtype
34- )
35- } ) ?;
36-
37- let expr = pack (
38- select
39- . data ( )
40- . as_include_names ( child_dtype. names ( ) )
41- . map_err ( |e| {
42- e. with_context ( format ! (
43- "Select fields {:?} must be a subset of child fields {:?}" ,
44- select. data( ) ,
45- child_dtype. names( )
46- ) )
47- } ) ?
48- . iter ( )
49- . map ( |name| ( name. clone ( ) , get_item ( name. clone ( ) , child. clone ( ) ) ) ) ,
50- child_nullability,
51- ) ;
52-
53- Ok ( Transformed :: yes ( expr) )
54- }
24+ if let Some ( select) = node. as_opt :: < Select > ( ) {
25+ let child = select. child ( ) ;
26+ let child_dtype = child. return_dtype ( ctx) ?;
27+ let child_nullability = child_dtype. nullability ( ) ;
28+
29+ let child_dtype = child_dtype. as_struct_fields_opt ( ) . ok_or_else ( || {
30+ vortex_err ! (
31+ "Select child must return a struct dtype, however it was a {}" ,
32+ child_dtype
33+ )
34+ } ) ?;
35+
36+ let expr = pack (
37+ select
38+ . data ( )
39+ . as_include_names ( child_dtype. names ( ) )
40+ . map_err ( |e| {
41+ e. with_context ( format ! (
42+ "Select fields {:?} must be a subset of child fields {:?}" ,
43+ select. data( ) ,
44+ child_dtype. names( )
45+ ) )
46+ } ) ?
47+ . iter ( )
48+ . map ( |name| ( name. clone ( ) , get_item ( name. clone ( ) , child. clone ( ) ) ) ) ,
49+ child_nullability,
50+ ) ;
51+
52+ Ok ( Transformed :: yes ( expr) )
53+ } else {
54+ Ok ( Transformed :: no ( node) )
55+ }
56+ }
57+
58+ /// Rule that removes Select expressions by converting them to Pack + GetItem.
59+ ///
60+ /// Transforms: `select(["a", "b"], expr)` → `pack(a: get_item("a", expr), b: get_item("b", expr))`
61+ pub struct RemoveSelectRule ;
62+
63+ impl ReduceRule for RemoveSelectRule {
64+ fn id ( & self ) -> ExprId {
65+ ExprId :: new_ref ( "vortex.select" )
66+ }
67+
68+ fn reduce (
69+ & self ,
70+ expr : & Expression ,
71+ ctx : & dyn RewriteContext ,
72+ ) -> VortexResult < Option < Expression > > {
73+ let Some ( select) = expr. as_opt :: < Select > ( ) else {
74+ return Ok ( None ) ;
75+ } ;
76+
77+ let child = select. child ( ) ;
78+ let child_dtype = child. return_dtype ( ctx. dtype ( ) ) ?;
79+ let child_nullability = child_dtype. nullability ( ) ;
80+
81+ let child_dtype = child_dtype. as_struct_fields_opt ( ) . ok_or_else ( || {
82+ vortex_err ! (
83+ "Select child must return a struct dtype, however it was a {}" ,
84+ child_dtype
85+ )
86+ } ) ?;
87+
88+ let expr = pack (
89+ select
90+ . data ( )
91+ . as_include_names ( child_dtype. names ( ) )
92+ . map_err ( |e| {
93+ e. with_context ( format ! (
94+ "Select fields {:?} must be a subset of child fields {:?}" ,
95+ select. data( ) ,
96+ child_dtype. names( )
97+ ) )
98+ } ) ?
99+ . iter ( )
100+ . map ( |name| ( name. clone ( ) , get_item ( name. clone ( ) , child. clone ( ) ) ) ) ,
101+ child_nullability,
102+ ) ;
103+
104+ Ok ( Some ( expr) )
55105 }
56106}
57107
@@ -61,10 +111,13 @@ mod tests {
61111 use vortex_dtype:: PType :: I32 ;
62112 use vortex_dtype:: { DType , StructFields } ;
63113
64- use super :: remove_select;
114+ use super :: { RemoveSelectRule , remove_select} ;
65115 use crate :: expr:: exprs:: pack:: Pack ;
66116 use crate :: expr:: exprs:: root:: root;
67117 use crate :: expr:: exprs:: select:: select;
118+ use crate :: expr:: session:: ExprSession ;
119+ use crate :: expr:: transform:: simplify_typed:: apply_child_rules;
120+ use crate :: expr:: transform:: traits:: { ReduceRule , SimpleRewriteContext } ;
68121
69122 #[ test]
70123 fn test_remove_select ( ) {
@@ -78,4 +131,50 @@ mod tests {
78131 assert ! ( e. is:: <Pack >( ) ) ;
79132 assert ! ( e. return_dtype( & dtype) . unwrap( ) . is_nullable( ) ) ;
80133 }
134+
/// Calls `RemoveSelectRule::reduce` directly on `select(["a", "b"], root())` over a nullable
/// struct dtype and checks that the result is a nullable `Pack` expression.
#[test]
fn test_remove_select_rule_direct() {
    let dtype = DType::Struct(
        StructFields::new(["a", "b"].into(), vec![I32.into(), I32.into()]),
        Nullable,
    );
    let e = select(["a", "b"], root());

    let rule = RemoveSelectRule;
    let ctx = SimpleRewriteContext { dtype: &dtype };

    // A single `expect` both asserts that the rule fired and yields the rewritten
    // expression, with a readable message if it did not.
    let transformed = rule
        .reduce(&e, &ctx)
        .unwrap()
        .expect("RemoveSelectRule should rewrite a select expression");

    assert!(transformed.is::<Pack>());
    assert!(transformed.return_dtype(&dtype).unwrap().is_nullable());
}
152+
/// Runs `apply_child_rules` with a default `ExprSession` on `select(["a", "c"], root())` and
/// checks that the result is a `Pack` whose struct dtype has exactly the fields `a` and `c`,
/// in that order.
#[test]
fn test_remove_select_via_session() {
    let scope_dtype = DType::Struct(
        StructFields::new(
            ["a", "b", "c"].into(),
            vec![I32.into(), I32.into(), I32.into()],
        ),
        Nullable,
    );

    // select(["a", "c"], root()) drops the middle field `b`.
    let selection = select(["a", "c"], root());

    // The default session is expected to have RemoveSelectRule registered.
    let session = ExprSession::default();
    let rewritten = apply_child_rules(selection, &scope_dtype, &session).unwrap();

    // The select must have been replaced by a Pack.
    assert!(rewritten.is::<Pack>());

    // Only the selected fields survive, in selection order.
    let rewritten_dtype = rewritten.return_dtype(&scope_dtype).unwrap();
    let fields = rewritten_dtype.as_struct_fields_opt().unwrap();
    let expected = ["a", "c"];
    assert_eq!(fields.names().len(), expected.len());
    for (idx, name) in expected.iter().enumerate() {
        assert_eq!(fields.names()[idx].as_ref(), *name);
    }
}
81180}
0 commit comments