Skip to content

Commit 23af416

Browse files
authored
fix: Do not allow inferring (-1) the dimension on any Expr.reshape dimension except the first (pola-rs#24591)
1 parent: 8d64e3d · commit: 23af416

File tree

3 files changed

+98
-75
lines changed

3 files changed

+98
-75
lines changed

crates/polars-plan/src/plans/aexpr/function_expr/schema.rs

Lines changed: 79 additions & 72 deletions
Original file line number | Diff line number | Diff line change
@@ -56,8 +56,10 @@ impl IRFunctionExpr {
5656
#[cfg(feature = "trigonometry")]
5757
Atan2 => mapper.map_to_float_dtype(),
5858
#[cfg(feature = "sign")]
59-
Sign => mapper.ensure_satisfies(|_, dtype| dtype.is_primitive_numeric(), "sign")?.with_same_dtype(),
60-
FillNull => mapper.map_to_supertype(),
59+
Sign => mapper
60+
.ensure_satisfies(|_, dtype| dtype.is_primitive_numeric(), "sign")?
61+
.with_same_dtype(),
62+
FillNull => mapper.map_to_supertype(),
6163
#[cfg(feature = "rolling_window")]
6264
RollingExpr { function, options } => {
6365
use IRRollingFunction::*;
@@ -67,7 +69,7 @@ impl IRFunctionExpr {
6769
Var => mapper.var_dtype(),
6870
Sum => mapper.sum_dtype(),
6971
#[cfg(feature = "cov")]
70-
CorrCov {..} => mapper.map_to_float_dtype(),
72+
CorrCov { .. } => mapper.map_to_float_dtype(),
7173
#[cfg(feature = "moment")]
7274
Skew | Kurtosis => mapper.map_to_float_dtype(),
7375
Map(_) => mapper.try_map_field(|field| {
@@ -84,20 +86,22 @@ impl IRFunctionExpr {
8486
}
8587
},
8688
#[cfg(feature = "rolling_window_by")]
87-
RollingExprBy{function_by, ..} => {
89+
RollingExprBy { function_by, .. } => {
8890
use IRRollingFunctionBy::*;
8991
match function_by {
9092
MinBy | MaxBy => mapper.with_same_dtype(),
91-
MeanBy | QuantileBy | StdBy=> mapper.moment_dtype(),
93+
MeanBy | QuantileBy | StdBy => mapper.moment_dtype(),
9294
VarBy => mapper.var_dtype(),
9395
SumBy => mapper.sum_dtype(),
9496
}
9597
},
9698
Rechunk => mapper.with_same_dtype(),
97-
Append { upcast } => if *upcast {
98-
mapper.map_to_supertype()
99-
} else {
100-
mapper.with_same_dtype()
99+
Append { upcast } => {
100+
if *upcast {
101+
mapper.map_to_supertype()
102+
} else {
103+
mapper.with_same_dtype()
104+
}
101105
},
102106
ShiftAndFill => mapper.with_same_dtype(),
103107
DropNans => mapper.with_same_dtype(),
@@ -131,10 +135,16 @@ impl IRFunctionExpr {
131135
#[cfg(feature = "dtype-struct")]
132136
AsStruct => {
133137
let mut field_names = PlHashSet::with_capacity(fields.len() - 1);
134-
let struct_fields = fields.iter().map(|f| {
135-
polars_ensure!(field_names.insert(f.name.as_str()), duplicate_field = f.name());
136-
Ok(f.clone())
137-
}).collect::<PolarsResult<Vec<_>>>()?;
138+
let struct_fields = fields
139+
.iter()
140+
.map(|f| {
141+
polars_ensure!(
142+
field_names.insert(f.name.as_str()),
143+
duplicate_field = f.name()
144+
);
145+
Ok(f.clone())
146+
})
147+
.collect::<PolarsResult<Vec<_>>>()?;
138148
Ok(Field::new(
139149
fields[0].name().clone(),
140150
DataType::Struct(struct_fields),
@@ -265,41 +275,14 @@ impl IRFunctionExpr {
265275
RepeatBy => mapper.map_dtype(|dt| DataType::List(dt.clone().into())),
266276
#[cfg(feature = "dtype-array")]
267277
Reshape(dims) => mapper.try_map_dtype(|dt: &DataType| {
268-
let dtype = dt.inner_dtype().unwrap_or(dt).clone();
269-
270-
if dims.len() == 1 {
271-
return Ok(dtype);
272-
}
273-
274-
let num_infers = dims.iter().filter(|d| matches!(d, ReshapeDimension::Infer)).count();
275-
276-
polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");
277-
278-
let mut inferred_size = 0;
279-
if num_infers == 1 {
280-
let mut total_size = 1u64;
281-
let mut current = dt;
282-
while let DataType::Array(dt, width) = current {
283-
if *width == 0 {
284-
total_size = 0;
285-
break;
286-
}
287-
288-
current = dt.as_ref();
289-
total_size *= *width as u64;
290-
}
291-
292-
let current_size = dims.iter().map(|d| d.get_or_infer(1)).product::<u64>();
293-
inferred_size = total_size / current_size;
294-
}
295-
296-
let mut prev_dtype = dtype.leaf_dtype().clone();
297-
298-
// We pop the outer dimension as that is the height of the series.
299-
for dim in &dims[1..] {
300-
prev_dtype = DataType::Array(Box::new(prev_dtype), dim.get_or_infer(inferred_size) as usize);
278+
let mut wrapped_dtype = dt.leaf_dtype().clone();
279+
for dim in dims[1..].iter().rev() {
280+
let Some(array_size) = dim.get() else {
281+
polars_bail!(InvalidOperation: "can only infer the first dimension");
282+
};
283+
wrapped_dtype = DataType::Array(Box::new(wrapped_dtype), array_size as usize);
301284
}
302-
Ok(prev_dtype)
285+
Ok(wrapped_dtype)
303286
}),
304287
#[cfg(feature = "cutqcut")]
305288
QCut {
@@ -350,37 +333,56 @@ impl IRFunctionExpr {
350333
Some(dtype) => mapper.with_dtype(dtype.clone()),
351334
},
352335
#[cfg(feature = "dtype-struct")]
353-
CumReduceHorizontal {
354-
return_dtype, ..
355-
}=> match return_dtype {
336+
CumReduceHorizontal { return_dtype, .. } => match return_dtype {
356337
None => mapper.with_dtype(DataType::Struct(fields.to_vec())),
357-
Some(dtype) => mapper.with_dtype(DataType::Struct(fields.iter().map(|f| Field::new(f.name().clone(), dtype.clone())).collect())),
338+
Some(dtype) => mapper.with_dtype(DataType::Struct(
339+
fields
340+
.iter()
341+
.map(|f| Field::new(f.name().clone(), dtype.clone()))
342+
.collect(),
343+
)),
358344
},
359345
#[cfg(feature = "dtype-struct")]
360-
CumFoldHorizontal { return_dtype, include_init, .. } => match return_dtype {
361-
None => mapper.with_dtype(DataType::Struct(fields.iter().skip(usize::from(!include_init)).map(|f| Field::new(f.name().clone(), fields[0].dtype().clone())).collect())),
362-
Some(dtype) => mapper.with_dtype(DataType::Struct(fields.iter().skip(usize::from(!include_init)).map(|f| Field::new(f.name().clone(), dtype.clone())).collect())),
346+
CumFoldHorizontal {
347+
return_dtype,
348+
include_init,
349+
..
350+
} => match return_dtype {
351+
None => mapper.with_dtype(DataType::Struct(
352+
fields
353+
.iter()
354+
.skip(usize::from(!include_init))
355+
.map(|f| Field::new(f.name().clone(), fields[0].dtype().clone()))
356+
.collect(),
357+
)),
358+
Some(dtype) => mapper.with_dtype(DataType::Struct(
359+
fields
360+
.iter()
361+
.skip(usize::from(!include_init))
362+
.map(|f| Field::new(f.name().clone(), dtype.clone()))
363+
.collect(),
364+
)),
363365
},
364366

365367
MaxHorizontal => mapper.map_to_supertype(),
366368
MinHorizontal => mapper.map_to_supertype(),
367-
SumHorizontal { .. } => {
368-
mapper.map_to_supertype().map(|mut f| {
369-
if f.dtype == DataType::Boolean {
370-
f.dtype = IDX_DTYPE;
371-
}
372-
f
373-
})
374-
},
375-
MeanHorizontal { .. } => {
376-
mapper.map_to_supertype().map(|mut f| {
377-
match f.dtype {
378-
dt @ DataType::Float32 => { f.dtype = dt; },
379-
_ => { f.dtype = DataType::Float64; },
380-
};
381-
f
382-
})
383-
}
369+
SumHorizontal { .. } => mapper.map_to_supertype().map(|mut f| {
370+
if f.dtype == DataType::Boolean {
371+
f.dtype = IDX_DTYPE;
372+
}
373+
f
374+
}),
375+
MeanHorizontal { .. } => mapper.map_to_supertype().map(|mut f| {
376+
match f.dtype {
377+
dt @ DataType::Float32 => {
378+
f.dtype = dt;
379+
},
380+
_ => {
381+
f.dtype = DataType::Float64;
382+
},
383+
};
384+
f
385+
}),
384386
#[cfg(feature = "ewma")]
385387
EwmMean { .. } => mapper.map_numeric_to_float_dtype(true),
386388
#[cfg(feature = "ewma_by")]
@@ -406,7 +408,12 @@ impl IRFunctionExpr {
406408
},
407409
ExtendConstant => mapper.with_same_dtype(),
408410

409-
RowEncode(..) => mapper.try_map_field(|_| Ok(Field::new(PlSmallStr::from_static("row_encoded"), DataType::BinaryOffset))),
411+
RowEncode(..) => mapper.try_map_field(|_| {
412+
Ok(Field::new(
413+
PlSmallStr::from_static("row_encoded"),
414+
DataType::BinaryOffset,
415+
))
416+
}),
410417
#[cfg(feature = "dtype-struct")]
411418
RowDecode(fields, _) => mapper.with_dtype(DataType::Struct(fields.to_vec())),
412419
}

py-polars/polars/expr/expr.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -9469,8 +9469,10 @@ def reshape(self, dimensions: tuple[int, ...]) -> Expr:
94699469
Parameters
94709470
----------
94719471
dimensions
9472-
Tuple of the dimension sizes. If a -1 is used in any of the dimensions, that
9473-
dimension is inferred.
9472+
Tuple of the dimension sizes. If -1 is used as the value for the
9473+
first dimension, that dimension is inferred.
9474+
Because the size of the Column may not be known in advance, it is
9475+
only possible to use -1 for the first dimension.
94749476
94759477
Returns
94769478
-------

py-polars/tests/unit/operations/test_reshape.py

Lines changed: 15 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,7 +14,8 @@ def display_shape(shape: tuple[int, ...]) -> str:
1414

1515

1616
def test_reshape() -> None:
17-
s = pl.Series("a", [1, 2, 3, 4])
17+
df = pl.DataFrame({"a": [1, 2, 3, 4]})
18+
s = df.to_series()
1819
out = s.reshape((-1, 2))
1920
expected = pl.Series("a", [[1, 2], [3, 4]], dtype=pl.Array(pl.Int64, 2))
2021
assert_series_equal(out, expected)
@@ -47,6 +48,19 @@ def test_reshape() -> None:
4748
):
4849
s.reshape(())
4950

51+
# expr inferred dimension on non-first dimension
52+
with pytest.raises(
53+
InvalidOperationError, match="can only infer the first dimension"
54+
):
55+
df.select(pl.col("a").reshape((2, -1)))
56+
57+
df = pl.DataFrame({"a": list(range(2 * 3 * 5 * 7))})
58+
q1 = df.lazy().select(pl.col("a").reshape((3, 5, 7, 2)))
59+
q2 = df.lazy().select(pl.col("a").reshape((-1, 5, 7, 2)))
60+
assert q1.collect_schema() == q1.collect().schema
61+
assert q2.collect_schema() == q2.collect().schema
62+
assert q1.collect_schema() == q2.collect_schema()
63+
5064

5165
@pytest.mark.parametrize("shape", [(1, 3), (5, 1), (-1, 5), (3, -1)])
5266
def test_reshape_invalid_dimension_size(shape: tuple[int, ...]) -> None:

0 commit comments

Comments
 (0)