Skip to content

Commit e6f9b5a

Browse files
authored
Merge pull request #67 from v0lta/axis-support
Axis support
2 parents 71b6ccc + 866c4e5 commit e6f9b5a

20 files changed

+1035
-362
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name: Tests
22

3-
on: [ push, pull_request ]
3+
on: [ push ]
44

55
jobs:
66
tests:
@@ -20,7 +20,7 @@ jobs:
2020
run: pip install nox
2121
- name: Test with pytest
2222
run:
23-
nox -s test
23+
nox -s fast-test
2424
lint:
2525
name: lint
2626
runs-on: ubuntu-latest

README.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
:alt: GitHub Actions
1313

1414
.. image:: https://readthedocs.org/projects/pytorch-wavelet-toolbox/badge/?version=latest
15-
:target: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/?badge=latest
15+
:target: https://pytorch-wavelet-toolbox.readthedocs.io/en/latest/ptwt.html
1616
:alt: Documentation Status
1717

1818
.. image:: https://img.shields.io/pypi/pyversions/ptwt

src/ptwt/_util.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Utility methods to compute wavelet decompositions from a dataset."""
2-
from typing import List, Optional, Protocol, Sequence, Tuple, Union
2+
from typing import Any, Callable, List, Optional, Protocol, Sequence, Tuple, Union
33

4+
import numpy as np
45
import pywt
56
import torch
67

@@ -103,19 +104,80 @@ def _pad_symmetric(
103104
return signal
104105

105106

106-
def _fold_channels(data: torch.Tensor) -> torch.Tensor:
107-
"""Fold [batch, channel, height width] into [batch*channel, height, widht]."""
108-
ds = data.shape
109-
return torch.reshape(
110-
data,
111-
[
112-
ds[0] * ds[1],
113-
ds[2],
114-
ds[3],
115-
],
107+
def _fold_axes(data: torch.Tensor, keep_no: int) -> Tuple[torch.Tensor, List[int]]:
108+
"""Fold unchanged leading dimensions into a single batch dimension.
109+
110+
Args:
111+
data ( torch.Tensor): The input data array.
112+
keep_no (int): The number of dimensions to keep.
113+
114+
Returns:
115+
Tuple[ torch.Tensor, List[int]]:
116+
The folded result array, and the shape of the original input.
117+
"""
118+
dshape = list(data.shape)
119+
return (
120+
torch.reshape(data, [int(np.prod(dshape[:-keep_no]))] + dshape[-keep_no:]),
121+
dshape,
116122
)
117123

118124

119-
def _unfold_channels(data: torch.Tensor, ds: List[int]) -> torch.Tensor:
120-
"""Unfold [batch*channel, height, widht] into [batch, channel, height, width]."""
121-
return torch.reshape(data, [ds[0], ds[1], data.shape[1], data.shape[2]])
125+
def _unfold_axes(data: torch.Tensor, ds: List[int], keep_no: int) -> torch.Tensor:
126+
"""Unfold i.e. [batch*channel,height,widht] to [batch,channel,height,width]."""
127+
return torch.reshape(data, ds[:-keep_no] + list(data.shape[-keep_no:]))
128+
129+
130+
def _check_if_tensor(array: Any) -> torch.Tensor:
131+
if not isinstance(array, torch.Tensor):
132+
raise ValueError(
133+
"First element of coeffs must be the approximation coefficient tensor."
134+
)
135+
return array
136+
137+
138+
def _check_axes_argument(axes: List[int]) -> None:
139+
if len(set(axes)) != len(axes):
140+
raise ValueError("Cant transform the same axis twice.")
141+
142+
143+
def _get_transpose_order(
144+
axes: List[int], data_shape: List[int]
145+
) -> Tuple[List[int], List[int]]:
146+
axes = list(map(lambda a: a + len(data_shape) if a < 0 else a, axes))
147+
all_axes = list(range(len(data_shape)))
148+
remove_transformed = list(filter(lambda a: a not in axes, all_axes))
149+
return remove_transformed, axes
150+
151+
152+
def _swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
153+
_check_axes_argument(axes)
154+
front, back = _get_transpose_order(axes, list(data.shape))
155+
return torch.permute(data, front + back)
156+
157+
158+
def _undo_swap_axes(data: torch.Tensor, axes: List[int]) -> torch.Tensor:
159+
_check_axes_argument(axes)
160+
front, back = _get_transpose_order(axes, list(data.shape))
161+
restore_sorted = torch.argsort(torch.tensor(front + back)).tolist()
162+
return torch.permute(data, restore_sorted)
163+
164+
165+
def _map_result(
166+
data: List[Union[torch.Tensor, Any]], # following jax tree_map typing can be Any
167+
function: Callable[[Any], torch.Tensor],
168+
) -> List[Union[torch.Tensor, Any]]:
169+
# Apply the given function to the input list of tensor and tuples.
170+
result_lst: List[Union[torch.Tensor, Any]] = []
171+
for element in data:
172+
if isinstance(element, torch.Tensor):
173+
result_lst.append(function(element))
174+
elif isinstance(element, tuple):
175+
result_lst.append(
176+
(function(element[0]), function(element[1]), function(element[2]))
177+
)
178+
elif isinstance(element, dict):
179+
new_dict = {}
180+
for key, value in element.items():
181+
new_dict[key] = function(value)
182+
result_lst.append(new_dict)
183+
return result_lst

src/ptwt/continuous_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,8 @@ def wavefun(
270270
"""Define a grid and evaluate the wavelet on it."""
271271
length = 2**precision
272272
# load the bounds from untyped pywt code.
273-
lower_bound: float = float(self.lower_bound) # type: ignore
274-
upper_bound: float = float(self.upper_bound) # type: ignore
273+
lower_bound: float = float(self.lower_bound)
274+
upper_bound: float = float(self.upper_bound)
275275
grid = torch.linspace(
276276
lower_bound,
277277
upper_bound,

src/ptwt/conv_transform.py

Lines changed: 89 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
from ._util import (
1212
Wavelet,
1313
_as_wavelet,
14-
_fold_channels,
14+
_fold_axes,
1515
_get_len,
1616
_is_dtype_supported,
1717
_pad_symmetric,
18-
_unfold_channels,
18+
_unfold_axes,
1919
)
2020

2121

@@ -217,47 +217,66 @@ def _adjust_padding_at_reconstruction(
217217
return pad_end, pad_start
218218

219219

220-
def _wavedec_fold_channels_1d(data: torch.Tensor) -> Tuple[torch.Tensor, List[int]]:
221-
data = data.unsqueeze(-2)
222-
ds = data.shape
223-
data = _fold_channels(data)
224-
return data, list(ds)
220+
def _preprocess_tensor_dec1d(
221+
data: torch.Tensor,
222+
) -> Tuple[torch.Tensor, Union[List[int], None]]:
223+
"""Preprocess input tensor dimensions.
225224
225+
Args:
226+
data (torch.Tensor): An input tensor of any shape.
226227
227-
def _wavedec_unfold_channels_1d_list(
228-
result_list: List[torch.Tensor], ds: List[int]
228+
Returns:
229+
Tuple[torch.Tensor, Union[List[int], None]]:
230+
A data tensor of shape [new_batch, 1, to_process]
231+
and the original shape, if the shape has changed.
232+
"""
233+
ds = None
234+
if len(data.shape) == 1:
235+
# assume time series
236+
data = data.unsqueeze(0).unsqueeze(0)
237+
elif len(data.shape) == 2:
238+
# assume batched time series
239+
data = data.unsqueeze(1)
240+
else:
241+
data, ds = _fold_axes(data, 1)
242+
data = data.unsqueeze(1)
243+
return data, ds
244+
245+
246+
def _postprocess_result_list_dec1d(
247+
result_lst: List[torch.Tensor], ds: List[int]
229248
) -> List[torch.Tensor]:
230-
unfold_res = []
231-
for res_coeff in result_list:
232-
unfold_res.append(
233-
_unfold_channels(res_coeff.unsqueeze(1), list(ds)).squeeze(-2)
234-
)
235-
return unfold_res
249+
# Unfold axes for the wavelets
250+
unfold_list = []
251+
for fres in result_lst:
252+
unfold_list.append(_unfold_axes(fres, ds, 1))
253+
return unfold_list
236254

237255

238-
def _waverec_fold_channels_1d_list(
239-
coeff_list: List[torch.Tensor],
256+
def _preprocess_result_list_rec1d(
257+
result_lst: List[torch.Tensor],
240258
) -> Tuple[List[torch.Tensor], List[int]]:
241-
folded = []
242-
ds = coeff_list[0].unsqueeze(-2).shape
243-
for to_fold_coeff in coeff_list:
244-
folded.append(_fold_channels(to_fold_coeff.unsqueeze(-2)).squeeze(-2))
245-
return folded, list(ds)
259+
# Fold axes for the wavelets
260+
fold_coeffs = []
261+
ds = list(result_lst[0].shape)
262+
for uf_coeff in result_lst:
263+
f_coeff, _ = _fold_axes(uf_coeff, 1)
264+
fold_coeffs.append(f_coeff)
265+
return fold_coeffs, ds
246266

247267

248268
def wavedec(
249269
data: torch.Tensor,
250270
wavelet: Union[Wavelet, str],
251271
mode: str = "reflect",
252272
level: Optional[int] = None,
273+
axis: int = -1,
253274
) -> List[torch.Tensor]:
254275
"""Compute the analysis (forward) 1d fast wavelet transform.
255276
256277
Args:
257278
data (torch.Tensor): The input time series,
258-
1d inputs are interpreted as ``[time]``,
259-
2d inputs as ``[batch_size, time]``,
260-
and 3d inputs as ``[batch_size, channels, time]``.
279+
By default the last axis is transformed.
261280
wavelet (Wavelet or str): A pywt wavelet compatible object or
262281
the name of a pywt wavelet.
263282
Please consider the output from ``pywt.wavelist(kind='discrete')``
@@ -274,9 +293,11 @@ def wavedec(
274293
Zero padding pads zeros.
275294
Constant padding replicates border values.
276295
Periodic padding cyclically repeats samples.
277-
278296
level (int): The scale level to be computed.
279297
Defaults to None.
298+
axis (int): Compute the transform over this axis instead of the
299+
last one. Defaults to -1.
300+
280301
281302
Returns:
282303
list: A list::
@@ -287,7 +308,8 @@ def wavedec(
287308
approximation and D detail coefficients.
288309
289310
Raises:
290-
ValueError: If the dtype of the input data tensor is unsupported.
311+
ValueError: If the dtype of the input data tensor is unsupported or
312+
if more than one axis is provided.
291313
292314
Example:
293315
>>> import torch
@@ -300,21 +322,17 @@ def wavedec(
300322
>>> ptwt.wavedec(data_torch, pywt.Wavelet('haar'),
301323
>>> mode='zero', level=2)
302324
"""
303-
fold = False
304-
if data.dim() == 1:
305-
# assume time series
306-
data = data.unsqueeze(0).unsqueeze(0)
307-
elif data.dim() == 2:
308-
# assume batched time series
309-
data = data.unsqueeze(1)
310-
elif data.dim() == 3:
311-
# assume batch, channels, time -> fold channels
312-
fold = True
313-
data, ds = _wavedec_fold_channels_1d(data)
325+
if axis != -1:
326+
if isinstance(axis, int):
327+
data = data.swapaxes(axis, -1)
328+
else:
329+
raise ValueError("wavedec transforms a single axis only.")
314330

315331
if not _is_dtype_supported(data.dtype):
316332
raise ValueError(f"Input dtype {data.dtype} not supported")
317333

334+
data, ds = _preprocess_tensor_dec1d(data)
335+
318336
dec_lo, dec_hi, _, _ = _get_filter_tensors(
319337
wavelet, flip=True, device=data.device, dtype=data.dtype
320338
)
@@ -332,28 +350,38 @@ def wavedec(
332350
res_lo, res_hi = torch.split(res, 1, 1)
333351
result_list.append(res_hi.squeeze(1))
334352
result_list.append(res_lo.squeeze(1))
353+
result_list.reverse()
354+
355+
if ds:
356+
result_list = _postprocess_result_list_dec1d(result_list, ds)
335357

336-
# unfold if necessary
337-
if fold:
338-
result_list = _wavedec_unfold_channels_1d_list(result_list, ds)
358+
if axis != -1:
359+
swap = []
360+
for coeff in result_list:
361+
swap.append(coeff.swapaxes(axis, -1))
362+
result_list = swap
339363

340-
return result_list[::-1]
364+
return result_list
341365

342366

343-
def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.Tensor:
367+
def waverec(
368+
coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str], axis: int = -1
369+
) -> torch.Tensor:
344370
"""Reconstruct a signal from wavelet coefficients.
345371
346372
Args:
347373
coeffs (list): The wavelet coefficient list produced by wavedec.
348374
wavelet (Wavelet or str): A pywt wavelet compatible object or
349375
the name of a pywt wavelet.
376+
axis (int): Transform this axis instead of the last one. Defaults to -1.
350377
351378
Returns:
352379
torch.Tensor: The reconstructed signal.
353380
354381
Raises:
355382
ValueError: If the dtype of the coeffs tensor is unsupported or if the
356-
coefficients have incompatible shapes, dtypes or devices.
383+
coefficients have incompatible shapes, dtypes or devices or if
384+
more than one axis is provided.
357385
358386
Example:
359387
>>> import torch
@@ -379,11 +407,19 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
379407
elif torch_dtype != coeff.dtype:
380408
raise ValueError("coefficients must have the same dtype")
381409

410+
if axis != -1:
411+
swap = []
412+
if isinstance(axis, int):
413+
for coeff in coeffs:
414+
swap.append(coeff.swapaxes(axis, -1))
415+
coeffs = swap
416+
else:
417+
raise ValueError("waverec transforms a single axis only.")
418+
382419
# fold channels, if necessary.
383-
fold = False
384-
if coeffs[0].dim() == 3:
385-
fold = True
386-
coeffs, ds = _waverec_fold_channels_1d_list(coeffs)
420+
ds = None
421+
if coeffs[0].dim() >= 3:
422+
coeffs, ds = _preprocess_result_list_rec1d(coeffs)
387423

388424
_, _, rec_lo, rec_hi = _get_filter_tensors(
389425
wavelet, flip=False, device=torch_device, dtype=torch_dtype
@@ -408,7 +444,10 @@ def waverec(coeffs: List[torch.Tensor], wavelet: Union[Wavelet, str]) -> torch.T
408444
if padr > 0:
409445
res_lo = res_lo[..., :-padr]
410446

411-
if fold:
412-
res_lo = _unfold_channels(res_lo.unsqueeze(-2), list(ds)).squeeze(-2)
447+
if ds:
448+
res_lo = _unfold_axes(res_lo, ds, 1)
449+
450+
if axis != -1:
451+
res_lo = res_lo.swapaxes(axis, -1)
413452

414453
return res_lo

0 commit comments

Comments
 (0)