11"""Ensure pytorch's torch.jit.trace feature works properly."""
22
3- from typing import NamedTuple , Optional , Union
3+ from typing import Optional , Union
44
55import numpy as np
66import pytest
1313from tests ._mackey_glass import MackeyGenerator
1414
1515
16- class WaveletTuple (NamedTuple ):
17- """Replaces namedtuple("Wavelet", ("dec_lo", "dec_hi", "rec_lo", "rec_hi"))."""
18-
19- dec_lo : torch .Tensor
20- dec_hi : torch .Tensor
21- rec_lo : torch .Tensor
22- rec_hi : torch .Tensor
23-
24-
25- def _set_up_wavelet_tuple (wavelet : WaveletTuple , dtype : torch .dtype ) -> WaveletTuple :
26- return WaveletTuple (
27- torch .tensor (wavelet .dec_lo ).type (dtype ),
28- torch .tensor (wavelet .dec_hi ).type (dtype ),
29- torch .tensor (wavelet .rec_lo ).type (dtype ),
30- torch .tensor (wavelet .rec_hi ).type (dtype ),
31- )
32-
33-
3416def _to_jit_wavedec_fun (
3517 data : torch .Tensor , wavelet : Union [ptwt .Wavelet , str ], level : Optional [int ]
3618) -> list [torch .Tensor ]:
@@ -53,7 +35,7 @@ def test_conv_fwt_jit(
5335
5436 mackey_data_1 = torch .squeeze (generator (), - 1 ).type (dtype )
5537 wavelet = pywt .Wavelet (wavelet_string )
56- wavelet = _set_up_wavelet_tuple (wavelet , dtype )
38+ wavelet = ptwt . WaveletTensorTuple . from_wavelet (wavelet , dtype )
5739
5840 with pytest .warns (Warning ):
5941 jit_wavedec = torch .jit .trace ( # type: ignore
@@ -105,7 +87,7 @@ def test_conv_fwt_jit_2d() -> None:
10587 rec = _to_jit_waverec_2 (coeff , wavelet )
10688 assert np .allclose (rec .squeeze (1 ).numpy (), data .numpy ())
10789
108- wavelet = _set_up_wavelet_tuple (wavelet , dtype = torch .float64 )
90+ wavelet = ptwt . WaveletTensorTuple . from_wavelet (wavelet , dtype = torch .float64 )
10991 with pytest .warns (Warning ):
11092 jit_wavedec2 = torch .jit .trace ( # type: ignore
11193 _to_jit_wavedec_2 ,
@@ -159,7 +141,7 @@ def test_conv_fwt_jit_3d() -> None:
159141 rec = _to_jit_waverec_3 (coeff , wavelet )
160142 assert np .allclose (rec .squeeze (1 ).numpy (), data .numpy ())
161143
162- wavelet = _set_up_wavelet_tuple (wavelet , dtype = torch .float64 )
144+ wavelet = ptwt . WaveletTensorTuple . from_wavelet (wavelet , dtype = torch .float64 )
163145 with pytest .warns (Warning ):
164146 jit_wavedec3 = torch .jit .trace ( # type: ignore
165147 _to_jit_wavedec_3 ,
0 commit comments