
Commit d07de76

Tensor::unfold(dim, size, step) (#3751)
* [WIP] towards pytorch.unfold()
* torch
* Expand unfold impl.
  - switched to pytorch's return shape.
  - added burn-router
  - exposed unfold calculation module.
  - ndarray and candle both need either upstream support or work-arounds. candle has a PR in-flight (from me): huggingface/candle#3091
* docs
* Implement `unfold` operation for tensor backends.
  - Added `unfold` function for `candle`, `ndarray` backends with a copy-based implementation.
  - Updated function docs and tensor operation traits accordingly.
  - Incorporated window shape calculation for `unfold`.
* Simplify `into_ranges` call for tensor shape calculation.
* Remove redundant field repetition in `UnfoldOpIr` initialization.
* book
* Update `slice` implementation to use step-aware slice arguments in `unfold`.
* rustfmt; fix rebase errors.
* Optimize `unfold4d` implementation for zero-padding and unit-dilation cases. Update imports.
* Refactor `unfold` implementation by introducing `calculate_unfold_shape`.
  - Replaced `calculate_unfold_windows` with `calculate_unfold_shape` across tensor backends.
  - Updated function documentation to reflect the new shape calculation.
  - Added unit tests for `calculate_unfold_shape`.
  - Simplified shape handling in tensor unfold operations.
1 parent 4002ecc commit d07de76
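The message's "switched to pytorch's return shape" pins down the contract the rest of the diff implements. As a hedged restatement in plain Rust (the concrete numbers below are an illustration, not taken from the commit's tests):

```rust
fn main() {
    // PyTorch-style unfold(dim = 1, size = 3, step = 2) on a [2, 7] tensor:
    // dim 1 is replaced by the number of complete windows, and `size` is
    // appended as a new right-most dimension.
    let (d, size, step) = (7usize, 3usize, 2usize);
    let windows = (d - size) / step + 1; // = 3
    assert_eq!([2usize, windows, size], [2, 3, 3]); // [2, 7] -> [2, 3, 3]
}
```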

File tree: 40 files changed, +979 -99 lines

burn-book/src/building-blocks/tensor.md

Lines changed: 50 additions & 49 deletions
@@ -131,55 +131,56 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

The hunk re-aligns the table's columns and adds a single new row for `tensor.unfold`. The updated table:

| Burn | PyTorch Equivalent |
|----------------------------------------------|---------------------------------------------------------------------------|
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dtype()` | `tensor.dtype` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.roll(shifts, dims)` | `tensor.roll(shifts, dims)` |
| `tensor.roll_dim(shift, dim)` | `tensor.roll([shift], [dim])` |
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(s![range;step])` | `tensor[(*ranges,)]` or `tensor[start:end:step]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.slice_fill(ranges, value)` | `tensor[(*ranges,)] = value` |
| `tensor.slice_dim(dim, range)` | N/A |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.take(dim, indices)` | `numpy.take(tensor, indices, dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.t()` | `tensor.T` |
| `tensor.unfold(dim, size, step)` | `tensor.unfold(dim, size, step)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.unsqueeze_dims(dims)` | N/A |

### Numeric Operations
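The book lists only the signature, so a minimal usage sketch follows. The `NdArray` backend, the `arange` setup, and the explicit output-rank annotation are assumptions for illustration, not part of this diff:

```rust
use burn::backend::NdArray;
use burn::tensor::{Int, Tensor};

fn main() {
    let device = Default::default();
    // x = [0, 1, 2, 3, 4, 5, 6]
    let x = Tensor::<NdArray, 1, Int>::arange(0..7, &device);

    // Windows of size 3 advancing by step 2 start at offsets 0, 2, 4,
    // giving (7 - 3) / 2 + 1 = 3 windows and a PyTorch-style shape of [3, 3]:
    //   [[0, 1, 2], [2, 3, 4], [4, 5, 6]]
    let windows: Tensor<NdArray, 2, Int> = x.unfold(0, 3, 2);
    assert_eq!(windows.dims(), [3, 3]);
}
```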

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 9 additions & 0 deletions
```diff
@@ -107,4 +107,13 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
     fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
         B::bool_repeat_dim(tensor, dim, times)
     }
+
+    fn bool_unfold(
+        tensor: BoolTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> BoolTensor<Self> {
+        B::bool_unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 9 additions & 0 deletions
```diff
@@ -377,4 +377,13 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
     fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
         B::int_cast(tensor, dtype)
     }
+
+    fn int_unfold(
+        tensor: IntTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> IntTensor<Self> {
+        B::int_unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 9 additions & 0 deletions
```diff
@@ -2592,6 +2592,15 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
 
     // TODO: Implement float_prod and float_sum
     // https://github.com/tracel-ai/burn/issues/1458
+
+    fn float_unfold(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> FloatTensor<Self> {
+        AutodiffTensor::new(B::float_unfold(tensor.primitive, dim, size, step))
+    }
 }
 
 #[derive(Debug, Clone)]
```

crates/burn-candle/src/ops/base.rs

Lines changed: 28 additions & 4 deletions
```diff
@@ -1,13 +1,14 @@
+use std::cmp::max;
 use std::marker::PhantomData;
 
-use burn_tensor::{Element, Shape, TensorData, TensorMetadata, backend::Backend};
-use candle_core::WithDType;
-use half::{bf16, f16};
-
 use crate::{
     Candle, CandleDevice, CandleTensor,
     element::{CandleElement, FloatCandleElement, IntCandleElement},
 };
+use burn_tensor::ops::unfold::{calculate_unfold_shape, calculate_unfold_windows};
+use burn_tensor::{Element, Shape, TensorData, TensorMetadata, backend::Backend};
+use candle_core::{Layout, WithDType};
+use half::{bf16, f16};
 
 use super::tensor;
 
@@ -193,6 +194,29 @@ pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {
     CandleTensor::new(tensor.tensor.broadcast_as(shape.dims).unwrap())
 }
 
+pub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor {
+    let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
+    let windows = result_shape[dim];
+
+    let mut select_ranges = tensor.shape().into_ranges();
+    let new_axis = select_ranges.len();
+
+    let mut stack = Vec::with_capacity(windows);
+    for widx in 0..windows {
+        let start = widx * step;
+        let end = start + size;
+        select_ranges[dim] = start..end;
+
+        let window_slice = slice(tensor.clone(), &select_ranges);
+
+        // Append a unit axis first so `new_axis` exists, then swap it with
+        // `dim` so the window's elements land in the trailing `size` dim.
+        let window_slice = CandleTensor::new(window_slice.tensor.unsqueeze(new_axis).unwrap());
+        let window_slice = swap_dims(window_slice, dim, new_axis);
+
+        stack.push(window_slice);
+    }
+    cat(stack, dim)
+}
+
 pub fn sign(tensor: CandleTensor) -> CandleTensor {
     CandleTensor::new(tensor.tensor.sign().unwrap())
 }
```
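To make the copy-based loop concrete, here is a dependency-free sketch of the slice ranges it visits, under the same window-count convention used elsewhere in the commit:

```rust
fn main() {
    // shape[dim] = 7, size = 3, step = 2
    let (dim_len, size, step) = (7usize, 3usize, 2usize);

    // Complete windows only: (7 - 3) / 2 + 1 = 3.
    let windows = if dim_len >= size { (dim_len - size) / step + 1 } else { 0 };

    // Each iteration slices start..start + size and stacks the window
    // along `dim`, which is why the result is a copy, not a view.
    let ranges: Vec<_> = (0..windows)
        .map(|w| (w * step)..(w * step + size))
        .collect();
    assert_eq!(ranges, vec![0..3, 2..5, 4..7]);
}
```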

crates/burn-candle/src/ops/bool_tensor.rs

Lines changed: 10 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@ use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},
 };
 
-use super::base::{expand, permute};
+use super::base::{expand, permute, unfold};
 
 impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
     fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
@@ -136,4 +136,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
     fn bool_expand(tensor: BoolTensor<Self>, shape: Shape) -> BoolTensor<Self> {
         expand(tensor, shape)
     }
+
+    fn bool_unfold(
+        tensor: BoolTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> BoolTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-candle/src/ops/int_tensor.rs

Lines changed: 10 additions & 1 deletion
```diff
@@ -9,7 +9,7 @@ use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},
 };
 
-use super::base::{expand, permute, sign};
+use super::base::{expand, permute, sign, unfold};
 
 impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
     fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
@@ -384,6 +384,15 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
         expand(tensor, shape)
     }
 
+    fn int_unfold(
+        tensor: IntTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> IntTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
+
     fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
         sign(tensor)
     }
```

crates/burn-candle/src/ops/tensor.rs

Lines changed: 10 additions & 1 deletion
```diff
@@ -12,7 +12,7 @@ use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},
 };
 
-use super::base::{expand, permute, sign};
+use super::base::{expand, permute, sign, unfold};
 
 impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
     fn float_from_data(data: TensorData, device: &Device<Self>) -> CandleTensor {
@@ -460,6 +460,15 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
         expand(tensor, shape)
     }
 
+    fn float_unfold(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> FloatTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
+
     fn float_sign(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
         sign(tensor)
     }
```

crates/burn-cubecl/src/ops/base.rs

Lines changed: 42 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 use crate::{CubeRuntime, element::CubeElement, kernel, tensor::CubeTensor};
 use burn_common::tensor::{ReshapeAction, reshape_action};
+use burn_tensor::ops::unfold::calculate_unfold_shape;
 use burn_tensor::{
     Shape, TensorData,
     quantization::{QTensorPrimitive, QuantLevel},
@@ -213,3 +214,44 @@ pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], dim
 
     vec.unwrap_or(0)
 }
+
+/// Unfold windows along a dimension.
+///
+/// Returns a view of the tensor containing all complete windows of size `size`
+/// in dimension `dim`, where consecutive windows are offset by `step` elements.
+///
+/// The number of windows is `(shape[dim] - size) / step + 1` (integer division)
+/// when `shape[dim] >= size`, and `0` otherwise.
+///
+/// The new view replaces the unfolded dimension with two dimensions:
+/// one in the position of the original dimension, with size equal to the number
+/// of windows, and one appended as the right-most dimension, with size `size`.
+///
+/// # Arguments
+///
+/// * `tensor` - The input tensor to unfold, of shape ``[pre=..., shape[dim], post=...]``.
+/// * `dim` - The dimension to unfold.
+/// * `size` - The size of each unfolded window.
+/// * `step` - The step between consecutive windows.
+///
+/// # Returns
+///
+/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
+pub fn unfold<R: CubeRuntime>(
+    tensor: CubeTensor<R>,
+    dim: usize,
+    size: usize,
+    step: usize,
+) -> CubeTensor<R> {
+    let shape = calculate_unfold_shape(tensor.shape, dim, size, step);
+
+    // Zero-copy stride trick: advance `step` source elements per window along
+    // `dim`, and walk the original stride inside each window via a new
+    // trailing axis.
+    let d_stride = tensor.strides[dim];
+    let mut strides = tensor.strides.clone();
+    strides[dim] = step * d_stride;
+    strides.push(d_stride);
+
+    CubeTensor {
+        shape: shape.into(),
+        strides,
+        ..tensor
+    }
+}
```
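The stride arithmetic is the whole trick here, so a worked, dependency-free sketch may help. `unfold_view` is a hypothetical stand-in for `calculate_unfold_shape` plus the stride edits above, not an API from this commit:

```rust
/// Compute the (shape, strides) of a zero-copy unfold view.
fn unfold_view(
    shape: &[usize],
    strides: &[usize],
    dim: usize,
    size: usize,
    step: usize,
) -> (Vec<usize>, Vec<usize>) {
    // Complete windows only: (d - size) / step + 1, or 0 if the dim is too short.
    let windows = if shape[dim] >= size {
        (shape[dim] - size) / step + 1
    } else {
        0
    };

    let mut new_shape = shape.to_vec();
    new_shape[dim] = windows; // window index replaces the unfolded dim
    new_shape.push(size); // trailing axis indexes within a window

    let mut new_strides = strides.to_vec();
    new_strides[dim] = strides[dim] * step; // jump one step per window
    new_strides.push(strides[dim]); // original element stride inside a window

    (new_shape, new_strides)
}

fn main() {
    // A contiguous [2, 7] tensor has strides [7, 1]. Unfolding dim 1 with
    // size = 3, step = 2 yields shape [2, 3, 3] and strides [7, 2, 1],
    // without moving any data.
    let (shape, strides) = unfold_view(&[2, 7], &[7, 1], 1, 3, 2);
    assert_eq!(shape, vec![2, 3, 3]);
    assert_eq!(strides, vec![7, 2, 1]);
}
```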
