|
23 | 23 | import tensorflow.compat.v1 as tf
|
24 | 24 |
|
25 | 25 |
|
26 |
| -_matrix_cache = {} |
27 |
| - |
28 |
| - |
29 | 26 | __all__ = [
|
30 | 27 | "irdft_matrix",
|
31 | 28 | ]
|
@@ -60,23 +57,18 @@ def create_kernel(init):
|
60 | 57 | """
|
61 | 58 | shape = tuple(int(s) for s in shape)
|
62 | 59 | dtype = tf.as_dtype(dtype)
|
63 |
| - key = (tf.get_default_graph(), "irdft", shape, dtype.as_datatype_enum) |
64 |
| - matrix = _matrix_cache.get(key) |
65 |
| - if matrix is None: |
66 |
| - size = np.prod(shape) |
67 |
| - rank = len(shape) |
68 |
| - matrix = np.identity(size, dtype=np.float64).reshape((size,) + shape) |
69 |
| - for axis in range(rank): |
70 |
| - matrix = fftpack.rfft(matrix, axis=axis + 1) |
71 |
| - slices = (rank + 1) * [slice(None)] |
72 |
| - if shape[axis] % 2 == 1: |
73 |
| - slices[axis + 1] = slice(1, None) |
74 |
| - else: |
75 |
| - slices[axis + 1] = slice(1, -1) |
76 |
| - matrix[tuple(slices)] *= np.sqrt(2) |
77 |
| - matrix /= np.sqrt(size) |
78 |
| - matrix = np.reshape(matrix, (size, size)) |
79 |
| - matrix = tf.constant( |
80 |
| - matrix, dtype=dtype, name="irdft_" + "x".join([str(s) for s in shape])) |
81 |
| - _matrix_cache[key] = matrix |
82 |
| - return matrix |
| 60 | + size = np.prod(shape) |
| 61 | + rank = len(shape) |
| 62 | + matrix = np.identity(size, dtype=np.float64).reshape((size,) + shape) |
| 63 | + for axis in range(rank): |
| 64 | + matrix = fftpack.rfft(matrix, axis=axis + 1) |
| 65 | + slices = (rank + 1) * [slice(None)] |
| 66 | + if shape[axis] % 2 == 1: |
| 67 | + slices[axis + 1] = slice(1, None) |
| 68 | + else: |
| 69 | + slices[axis + 1] = slice(1, -1) |
| 70 | + matrix[tuple(slices)] *= np.sqrt(2) |
| 71 | + matrix /= np.sqrt(size) |
| 72 | + matrix = np.reshape(matrix, (size, size)) |
| 73 | + return tf.constant( |
| 74 | + matrix, dtype=dtype, name="irdft_" + "x".join([str(s) for s in shape])) |
0 commit comments