@@ -56,8 +56,10 @@ impl IRFunctionExpr {
5656 #[ cfg( feature = "trigonometry" ) ]
5757 Atan2 => mapper. map_to_float_dtype ( ) ,
5858 #[ cfg( feature = "sign" ) ]
59- Sign => mapper. ensure_satisfies ( |_, dtype| dtype. is_primitive_numeric ( ) , "sign" ) ?. with_same_dtype ( ) ,
60- FillNull => mapper. map_to_supertype ( ) ,
59+ Sign => mapper
60+ . ensure_satisfies ( |_, dtype| dtype. is_primitive_numeric ( ) , "sign" ) ?
61+ . with_same_dtype ( ) ,
62+ FillNull => mapper. map_to_supertype ( ) ,
6163 #[ cfg( feature = "rolling_window" ) ]
6264 RollingExpr { function, options } => {
6365 use IRRollingFunction :: * ;
@@ -67,7 +69,7 @@ impl IRFunctionExpr {
6769 Var => mapper. var_dtype ( ) ,
6870 Sum => mapper. sum_dtype ( ) ,
6971 #[ cfg( feature = "cov" ) ]
70- CorrCov { .. } => mapper. map_to_float_dtype ( ) ,
72+ CorrCov { .. } => mapper. map_to_float_dtype ( ) ,
7173 #[ cfg( feature = "moment" ) ]
7274 Skew | Kurtosis => mapper. map_to_float_dtype ( ) ,
7375 Map ( _) => mapper. try_map_field ( |field| {
@@ -84,20 +86,22 @@ impl IRFunctionExpr {
8486 }
8587 } ,
8688 #[ cfg( feature = "rolling_window_by" ) ]
87- RollingExprBy { function_by, ..} => {
89+ RollingExprBy { function_by, .. } => {
8890 use IRRollingFunctionBy :: * ;
8991 match function_by {
9092 MinBy | MaxBy => mapper. with_same_dtype ( ) ,
91- MeanBy | QuantileBy | StdBy => mapper. moment_dtype ( ) ,
93+ MeanBy | QuantileBy | StdBy => mapper. moment_dtype ( ) ,
9294 VarBy => mapper. var_dtype ( ) ,
9395 SumBy => mapper. sum_dtype ( ) ,
9496 }
9597 } ,
9698 Rechunk => mapper. with_same_dtype ( ) ,
97- Append { upcast } => if * upcast {
98- mapper. map_to_supertype ( )
99- } else {
100- mapper. with_same_dtype ( )
99+ Append { upcast } => {
100+ if * upcast {
101+ mapper. map_to_supertype ( )
102+ } else {
103+ mapper. with_same_dtype ( )
104+ }
101105 } ,
102106 ShiftAndFill => mapper. with_same_dtype ( ) ,
103107 DropNans => mapper. with_same_dtype ( ) ,
@@ -131,10 +135,16 @@ impl IRFunctionExpr {
131135 #[ cfg( feature = "dtype-struct" ) ]
132136 AsStruct => {
133137 let mut field_names = PlHashSet :: with_capacity ( fields. len ( ) - 1 ) ;
134- let struct_fields = fields. iter ( ) . map ( |f| {
135- polars_ensure ! ( field_names. insert( f. name. as_str( ) ) , duplicate_field = f. name( ) ) ;
136- Ok ( f. clone ( ) )
137- } ) . collect :: < PolarsResult < Vec < _ > > > ( ) ?;
138+ let struct_fields = fields
139+ . iter ( )
140+ . map ( |f| {
141+ polars_ensure ! (
142+ field_names. insert( f. name. as_str( ) ) ,
143+ duplicate_field = f. name( )
144+ ) ;
145+ Ok ( f. clone ( ) )
146+ } )
147+ . collect :: < PolarsResult < Vec < _ > > > ( ) ?;
138148 Ok ( Field :: new (
139149 fields[ 0 ] . name ( ) . clone ( ) ,
140150 DataType :: Struct ( struct_fields) ,
@@ -265,41 +275,14 @@ impl IRFunctionExpr {
265275 RepeatBy => mapper. map_dtype ( |dt| DataType :: List ( dt. clone ( ) . into ( ) ) ) ,
266276 #[ cfg( feature = "dtype-array" ) ]
267277 Reshape ( dims) => mapper. try_map_dtype ( |dt : & DataType | {
268- let dtype = dt. inner_dtype ( ) . unwrap_or ( dt) . clone ( ) ;
269-
270- if dims. len ( ) == 1 {
271- return Ok ( dtype) ;
272- }
273-
274- let num_infers = dims. iter ( ) . filter ( |d| matches ! ( d, ReshapeDimension :: Infer ) ) . count ( ) ;
275-
276- polars_ensure ! ( num_infers <= 1 , InvalidOperation : "can only specify one inferred dimension" ) ;
277-
278- let mut inferred_size = 0 ;
279- if num_infers == 1 {
280- let mut total_size = 1u64 ;
281- let mut current = dt;
282- while let DataType :: Array ( dt, width) = current {
283- if * width == 0 {
284- total_size = 0 ;
285- break ;
286- }
287-
288- current = dt. as_ref ( ) ;
289- total_size *= * width as u64 ;
290- }
291-
292- let current_size = dims. iter ( ) . map ( |d| d. get_or_infer ( 1 ) ) . product :: < u64 > ( ) ;
293- inferred_size = total_size / current_size;
294- }
295-
296- let mut prev_dtype = dtype. leaf_dtype ( ) . clone ( ) ;
297-
298- // We pop the outer dimension as that is the height of the series.
299- for dim in & dims[ 1 ..] {
300- prev_dtype = DataType :: Array ( Box :: new ( prev_dtype) , dim. get_or_infer ( inferred_size) as usize ) ;
278+ let mut wrapped_dtype = dt. leaf_dtype ( ) . clone ( ) ;
279+ for dim in dims[ 1 ..] . iter ( ) . rev ( ) {
280+ let Some ( array_size) = dim. get ( ) else {
281+ polars_bail ! ( InvalidOperation : "can only infer the first dimension" ) ;
282+ } ;
283+ wrapped_dtype = DataType :: Array ( Box :: new ( wrapped_dtype) , array_size as usize ) ;
301284 }
302- Ok ( prev_dtype )
285+ Ok ( wrapped_dtype )
303286 } ) ,
304287 #[ cfg( feature = "cutqcut" ) ]
305288 QCut {
@@ -350,37 +333,56 @@ impl IRFunctionExpr {
350333 Some ( dtype) => mapper. with_dtype ( dtype. clone ( ) ) ,
351334 } ,
352335 #[ cfg( feature = "dtype-struct" ) ]
353- CumReduceHorizontal {
354- return_dtype, ..
355- } => match return_dtype {
336+ CumReduceHorizontal { return_dtype, .. } => match return_dtype {
356337 None => mapper. with_dtype ( DataType :: Struct ( fields. to_vec ( ) ) ) ,
357- Some ( dtype) => mapper. with_dtype ( DataType :: Struct ( fields. iter ( ) . map ( |f| Field :: new ( f. name ( ) . clone ( ) , dtype. clone ( ) ) ) . collect ( ) ) ) ,
338+ Some ( dtype) => mapper. with_dtype ( DataType :: Struct (
339+ fields
340+ . iter ( )
341+ . map ( |f| Field :: new ( f. name ( ) . clone ( ) , dtype. clone ( ) ) )
342+ . collect ( ) ,
343+ ) ) ,
358344 } ,
359345 #[ cfg( feature = "dtype-struct" ) ]
360- CumFoldHorizontal { return_dtype, include_init, .. } => match return_dtype {
361- None => mapper. with_dtype ( DataType :: Struct ( fields. iter ( ) . skip ( usize:: from ( !include_init) ) . map ( |f| Field :: new ( f. name ( ) . clone ( ) , fields[ 0 ] . dtype ( ) . clone ( ) ) ) . collect ( ) ) ) ,
362- Some ( dtype) => mapper. with_dtype ( DataType :: Struct ( fields. iter ( ) . skip ( usize:: from ( !include_init) ) . map ( |f| Field :: new ( f. name ( ) . clone ( ) , dtype. clone ( ) ) ) . collect ( ) ) ) ,
346+ CumFoldHorizontal {
347+ return_dtype,
348+ include_init,
349+ ..
350+ } => match return_dtype {
351+ None => mapper. with_dtype ( DataType :: Struct (
352+ fields
353+ . iter ( )
354+ . skip ( usize:: from ( !include_init) )
355+ . map ( |f| Field :: new ( f. name ( ) . clone ( ) , fields[ 0 ] . dtype ( ) . clone ( ) ) )
356+ . collect ( ) ,
357+ ) ) ,
358+ Some ( dtype) => mapper. with_dtype ( DataType :: Struct (
359+ fields
360+ . iter ( )
361+ . skip ( usize:: from ( !include_init) )
362+ . map ( |f| Field :: new ( f. name ( ) . clone ( ) , dtype. clone ( ) ) )
363+ . collect ( ) ,
364+ ) ) ,
363365 } ,
364366
365367 MaxHorizontal => mapper. map_to_supertype ( ) ,
366368 MinHorizontal => mapper. map_to_supertype ( ) ,
367- SumHorizontal { .. } => {
368- mapper . map_to_supertype ( ) . map ( | mut f| {
369- if f. dtype == DataType :: Boolean {
370- f . dtype = IDX_DTYPE ;
371- }
372- f
373- } )
374- } ,
375- MeanHorizontal { .. } => {
376- mapper . map_to_supertype ( ) . map ( | mut f| {
377- match f . dtype {
378- dt @ DataType :: Float32 => { f . dtype = dt ; } ,
379- _ => { f. dtype = DataType :: Float64 ; } ,
380- } ;
381- f
382- } )
383- }
369+ SumHorizontal { .. } => mapper . map_to_supertype ( ) . map ( | mut f| {
370+ if f . dtype == DataType :: Boolean {
371+ f. dtype = IDX_DTYPE ;
372+ }
373+ f
374+ } ) ,
375+ MeanHorizontal { .. } => mapper . map_to_supertype ( ) . map ( | mut f| {
376+ match f . dtype {
377+ dt @ DataType :: Float32 => {
378+ f . dtype = dt ;
379+ } ,
380+ _ => {
381+ f. dtype = DataType :: Float64 ;
382+ } ,
383+ } ;
384+ f
385+ } ) ,
384386 #[ cfg( feature = "ewma" ) ]
385387 EwmMean { .. } => mapper. map_numeric_to_float_dtype ( true ) ,
386388 #[ cfg( feature = "ewma_by" ) ]
@@ -406,7 +408,12 @@ impl IRFunctionExpr {
406408 } ,
407409 ExtendConstant => mapper. with_same_dtype ( ) ,
408410
409- RowEncode ( ..) => mapper. try_map_field ( |_| Ok ( Field :: new ( PlSmallStr :: from_static ( "row_encoded" ) , DataType :: BinaryOffset ) ) ) ,
411+ RowEncode ( ..) => mapper. try_map_field ( |_| {
412+ Ok ( Field :: new (
413+ PlSmallStr :: from_static ( "row_encoded" ) ,
414+ DataType :: BinaryOffset ,
415+ ) )
416+ } ) ,
410417 #[ cfg( feature = "dtype-struct" ) ]
411418 RowDecode ( fields, _) => mapper. with_dtype ( DataType :: Struct ( fields. to_vec ( ) ) ) ,
412419 }
0 commit comments