@@ -208,6 +208,7 @@ mod tests {
208208 use crate :: expr:: exprs:: pack:: pack;
209209 use crate :: expr:: exprs:: root:: root;
210210 use crate :: expr:: exprs:: select:: select;
211+ use crate :: expr:: session:: ExprSession ;
211212 use crate :: expr:: transform:: immediate_access:: annotate_scope_access;
212213 use crate :: expr:: transform:: replace:: replace_root_fields;
213214 use crate :: expr:: transform:: simplify_typed:: simplify_typed;
@@ -233,13 +234,14 @@ mod tests {
233234 #[ rstest]
234235 fn test_expr_top_level_ref ( dtype : DType ) {
235236 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
237+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
236238
237239 let expr = root ( ) ;
238240 let partitioned = partition (
239241 expr. clone ( ) ,
240242 & dtype,
241243 annotate_scope_access ( fields) ,
242- & ExprSession :: default ( ) ,
244+ & optimizer ,
243245 )
244246 . unwrap ( ) ;
245247
@@ -249,36 +251,28 @@ mod tests {
249251
250252 // Instead, callers must expand the root expression themselves.
251253 let expr = replace_root_fields ( expr, fields) ;
252- let partitioned = partition (
253- expr,
254- & dtype,
255- annotate_scope_access ( fields) ,
256- & ExprSession :: default ( ) ,
257- )
258- . unwrap ( ) ;
254+ let partitioned =
255+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
259256
260257 assert_eq ! ( partitioned. partitions. len( ) , fields. names( ) . len( ) ) ;
261258 }
262259
263260 #[ rstest]
264261 fn test_expr_top_level_ref_get_item_and_split ( dtype : DType ) {
265262 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
263+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
266264
267265 let expr = get_item ( "y" , get_item ( "a" , root ( ) ) ) ;
268266
269- let partitioned = partition (
270- expr,
271- & dtype,
272- annotate_scope_access ( fields) ,
273- & ExprSession :: default ( ) ,
274- )
275- . unwrap ( ) ;
267+ let partitioned =
268+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
276269 assert_eq ! ( & partitioned. root, & get_item( "a_0" , get_item( "a" , root( ) ) ) ) ;
277270 }
278271
279272 #[ rstest]
280273 fn test_expr_top_level_ref_get_item_and_split_pack ( dtype : DType ) {
281274 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
275+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
282276
283277 let expr = pack (
284278 [
@@ -288,17 +282,12 @@ mod tests {
288282 ] ,
289283 NonNullable ,
290284 ) ;
291- let partitioned = partition (
292- expr,
293- & dtype,
294- annotate_scope_access ( fields) ,
295- & ExprSession :: default ( ) ,
296- )
297- . unwrap ( ) ;
285+ let partitioned =
286+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
298287
299288 let split_a = partitioned. find_partition ( & "a" . into ( ) ) . unwrap ( ) ;
300289 assert_eq ! (
301- & simplify_typed( split_a. clone( ) , & dtype, & ExprSession :: default ( ) , ) . unwrap( ) ,
290+ & simplify_typed( split_a. clone( ) , & dtype, & ExprSession :: default ( ) ) . unwrap( ) ,
302291 & pack(
303292 [
304293 ( "a_0" , get_item( "x" , get_item( "a" , root( ) ) ) ) ,
@@ -312,15 +301,11 @@ mod tests {
312301 #[ rstest]
313302 fn test_expr_top_level_ref_get_item_add ( dtype : DType ) {
314303 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
304+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
315305
316306 let expr = and ( get_item ( "y" , get_item ( "a" , root ( ) ) ) , lit ( 1 ) ) ;
317- let partitioned = partition (
318- expr,
319- & dtype,
320- annotate_scope_access ( fields) ,
321- & ExprSession :: default ( ) ,
322- )
323- . unwrap ( ) ;
307+ let partitioned =
308+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
324309
325310 // Whole expr is a single split
326311 assert_eq ! ( partitioned. partitions. len( ) , 1 ) ;
@@ -329,15 +314,11 @@ mod tests {
329314 #[ rstest]
330315 fn test_expr_top_level_ref_get_item_add_cannot_split ( dtype : DType ) {
331316 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
317+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
332318
333319 let expr = and ( get_item ( "y" , get_item ( "a" , root ( ) ) ) , get_item ( "b" , root ( ) ) ) ;
334- let partitioned = partition (
335- expr,
336- & dtype,
337- annotate_scope_access ( fields) ,
338- & ExprSession :: default ( ) ,
339- )
340- . unwrap ( ) ;
320+ let partitioned =
321+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
341322
342323 // One for id.a and id.b
343324 assert_eq ! ( partitioned. partitions. len( ) , 2 ) ;
@@ -347,19 +328,15 @@ mod tests {
347328 #[ rstest]
348329 fn test_expr_partition_many_occurrences_of_field ( dtype : DType ) {
349330 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
331+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
350332
351333 let expr = and (
352334 get_item ( "y" , get_item ( "a" , root ( ) ) ) ,
353335 select ( [ "a" , "b" ] , root ( ) ) ,
354336 ) ;
355337 let expr = simplify_typed ( expr, & dtype, & ExprSession :: default ( ) ) . unwrap ( ) ;
356- let partitioned = partition (
357- expr,
358- & dtype,
359- annotate_scope_access ( fields) ,
360- & ExprSession :: default ( ) ,
361- )
362- . unwrap ( ) ;
338+ let partitioned =
339+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
363340
364341 // One for id.a and id.b
365342 assert_eq ! ( partitioned. partitions. len( ) , 2 ) ;
@@ -393,16 +370,12 @@ mod tests {
393370 #[ rstest]
394371 fn test_expr_merge ( dtype : DType ) {
395372 let fields = dtype. as_struct_fields_opt ( ) . unwrap ( ) ;
373+ let optimizer = ExprOptimizer :: new ( ExprSession :: default ( ) ) ;
396374
397375 let expr = merge ( [ col ( "a" ) , pack ( [ ( "b" , col ( "b" ) ) ] , NonNullable ) ] ) ;
398376
399- let partitioned = partition (
400- expr,
401- & dtype,
402- annotate_scope_access ( fields) ,
403- & ExprSession :: default ( ) ,
404- )
405- . unwrap ( ) ;
377+ let partitioned =
378+ partition ( expr, & dtype, annotate_scope_access ( fields) , & optimizer) . unwrap ( ) ;
406379 let expected = pack (
407380 [
408381 ( "x" , get_item ( "x" , get_item ( "a_0" , col ( "a" ) ) ) ) ,
0 commit comments