Skip to content

Commit 3f4c82a

Browse files
Johannes Ball?copybara-github
authored andcommitted
Removes caching mechanism from common constants.
PiperOrigin-RevId: 269857283 Change-Id: I51776f5ad25260440eeeedd81a45af1a2966e397
1 parent 3842edc commit 3f4c82a

File tree

1 file changed

+15
-23
lines changed

1 file changed

+15
-23
lines changed

tensorflow_compression/python/ops/spectral_ops.py

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@
2323
import tensorflow.compat.v1 as tf
2424

2525

26-
_matrix_cache = {}
27-
28-
2926
__all__ = [
3027
"irdft_matrix",
3128
]
@@ -60,23 +57,18 @@ def create_kernel(init):
6057
"""
6158
shape = tuple(int(s) for s in shape)
6259
dtype = tf.as_dtype(dtype)
63-
key = (tf.get_default_graph(), "irdft", shape, dtype.as_datatype_enum)
64-
matrix = _matrix_cache.get(key)
65-
if matrix is None:
66-
size = np.prod(shape)
67-
rank = len(shape)
68-
matrix = np.identity(size, dtype=np.float64).reshape((size,) + shape)
69-
for axis in range(rank):
70-
matrix = fftpack.rfft(matrix, axis=axis + 1)
71-
slices = (rank + 1) * [slice(None)]
72-
if shape[axis] % 2 == 1:
73-
slices[axis + 1] = slice(1, None)
74-
else:
75-
slices[axis + 1] = slice(1, -1)
76-
matrix[tuple(slices)] *= np.sqrt(2)
77-
matrix /= np.sqrt(size)
78-
matrix = np.reshape(matrix, (size, size))
79-
matrix = tf.constant(
80-
matrix, dtype=dtype, name="irdft_" + "x".join([str(s) for s in shape]))
81-
_matrix_cache[key] = matrix
82-
return matrix
60+
size = np.prod(shape)
61+
rank = len(shape)
62+
matrix = np.identity(size, dtype=np.float64).reshape((size,) + shape)
63+
for axis in range(rank):
64+
matrix = fftpack.rfft(matrix, axis=axis + 1)
65+
slices = (rank + 1) * [slice(None)]
66+
if shape[axis] % 2 == 1:
67+
slices[axis + 1] = slice(1, None)
68+
else:
69+
slices[axis + 1] = slice(1, -1)
70+
matrix[tuple(slices)] *= np.sqrt(2)
71+
matrix /= np.sqrt(size)
72+
matrix = np.reshape(matrix, (size, size))
73+
return tf.constant(
74+
matrix, dtype=dtype, name="irdft_" + "x".join([str(s) for s in shape]))

0 commit comments

Comments
 (0)