Skip to content

Commit cf20e36

Browse files
Johannes Ballécopybara-github
authored andcommitted
Reimplements RDFT kernel reparameterization using tf.signal.
PiperOrigin-RevId: 367329695 Change-Id: Ibcbc6cc96b48560494a3b9e2c0e2305c90e7a138
1 parent f74fc2a commit cf20e36

File tree

10 files changed

+75
-163
lines changed

10 files changed

+75
-163
lines changed

BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ py_library(
2525
"//tensorflow_compression/python/ops:math_ops",
2626
"//tensorflow_compression/python/ops:padding_ops",
2727
"//tensorflow_compression/python/ops:soft_round_ops",
28-
"//tensorflow_compression/python/ops:spectral_ops",
2928
"//tensorflow_compression/python/util:packed_tensors",
3029
],
3130
)

tensorflow_compression/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
from tensorflow_compression.python.ops.math_ops import *
3737
from tensorflow_compression.python.ops.padding_ops import *
3838
from tensorflow_compression.python.ops.soft_round_ops import *
39-
from tensorflow_compression.python.ops.spectral_ops import *
4039

4140
from tensorflow_compression.python.util.packed_tensors import *
4241
# pylint: enable=wildcard-import

tensorflow_compression/all_tests.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from tensorflow_compression.python.ops.padding_ops_test import *
4242
from tensorflow_compression.python.ops.range_coding_ops_test import *
4343
from tensorflow_compression.python.ops.soft_round_ops_test import *
44-
from tensorflow_compression.python.ops.spectral_ops_test import *
4544

4645
from tensorflow_compression.python.util.packed_tensors_test import *
4746
# pylint: enable=wildcard-import

tensorflow_compression/python/layers/BUILD

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,10 +38,7 @@ py_library(
3838
name = "parameters",
3939
srcs = ["parameters.py"],
4040
srcs_version = "PY3",
41-
deps = [
42-
"//tensorflow_compression/python/ops:math_ops",
43-
"//tensorflow_compression/python/ops:spectral_ops",
44-
],
41+
deps = ["//tensorflow_compression/python/ops:math_ops"],
4542
)
4643

4744
py_test(

tensorflow_compression/python/layers/parameters.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from typing import Any, Dict
1919
import tensorflow as tf
2020
from tensorflow_compression.python.ops import math_ops
21-
from tensorflow_compression.python.ops import spectral_ops
2221

2322

2423
__all__ = [
@@ -78,50 +77,64 @@ class RDFTParameter(Parameter):
7877
(see https://en.wikipedia.org/wiki/Discrete_Fourier_transform)
7978
8079
Attributes:
81-
dc: Boolean. The `dc` parameter provided on initialization.
8280
shape: `tf.TensorShape`. The shape of the convolution kernel.
83-
rdft: `tf.Variable`. The RDFT of the kernel.
81+
real: `tf.Variable`. The real part of the RDFT of the kernel.
82+
imag: `tf.Variable`. The imaginary part of the RDFT of the kernel.
8483
"""
8584

86-
def __init__(self, initial_value, name=None, dc=True, shape=None, dtype=None):
85+
def __init__(self, initial_value, name=None, shape=None, dtype=None):
8786
"""Initializer.
8887
8988
Args:
9089
initial_value: `tf.Tensor` or `None`. The initial value of the kernel. If
9190
not provided, its `shape` must be given, and the initial value of the
9291
parameter will be undefined.
9392
name: String. The name of the kernel.
94-
dc: Boolean. If `False`, the DC component of the kernel RDFTs is not
95-
represented, forcing the filters to be highpass. Defaults to `True`.
9693
shape: `tf.TensorShape` or compatible. Ignored unless `initial_value is
9794
None`.
9895
dtype: `tf.dtypes.DType` or compatible. DType of this parameter. If not
9996
given, inferred from `initial_value`.
10097
"""
10198
super().__init__(name=name)
102-
self._dc = bool(dc)
10399
if initial_value is None:
104100
if shape is None:
105101
raise ValueError("If initial_value is None, shape must be specified.")
106102
initial_value = tf.zeros(shape, dtype=dtype)
107103
else:
108104
initial_value = tf.convert_to_tensor(initial_value, dtype=dtype)
109105
self._shape = initial_value.shape
110-
self._matrix = spectral_ops.irdft_matrix(
111-
self.shape[:-2], dtype=initial_value.dtype)
112-
if not self.dc:
113-
self._matrix = self._matrix[:, 1:]
114-
initial_value = tf.reshape(
115-
initial_value, (-1, self.shape[-2] * self.shape[-1]))
116-
initial_value = tf.linalg.matmul(
117-
self._matrix, initial_value, transpose_a=True)
106+
self._dtype = initial_value.dtype
107+
108+
if self.shape.rank == 3:
109+
initial_value = tf.transpose(initial_value, (1, 2, 0))
110+
initial_value = tf.signal.rfft(initial_value)
111+
elif self.shape.rank == 4:
112+
initial_value = tf.transpose(initial_value, (2, 3, 0, 1))
113+
initial_value = tf.signal.rfft2d(initial_value)
114+
elif self.shape.rank == 5:
115+
initial_value = tf.transpose(initial_value, (3, 4, 0, 1, 2))
116+
initial_value = tf.signal.rfft3d(initial_value)
117+
else:
118+
raise ValueError(
119+
f"Expected kernel tensor of rank 3, 4, or 5; received shape "
120+
f"{self._shape}.")
121+
self._norm = tf.constant(
122+
self.shape[:-2].num_elements() ** .5, initial_value.dtype)
123+
initial_value /= self._norm
124+
# We split the variable into real and imaginary parts to avoid issues with
125+
# complex-valued variables being unsupported when saving models, etc.
126+
real = tf.math.real(initial_value)
127+
imag = tf.math.imag(initial_value)
128+
real_name = imag_name = None
118129
if name is not None:
119-
name = f"{name}_rdft"
120-
self.rdft = tf.Variable(initial_value, name=name)
130+
real_name = f"{name}_real"
131+
imag_name = f"{name}_imag"
132+
self.real = tf.Variable(real, name=real_name)
133+
self.imag = tf.Variable(imag, name=imag_name)
121134

122135
@property
123-
def dc(self) -> bool:
124-
return self._dc
136+
def dtype(self) -> tf.dtypes.DType:
137+
return self._dtype
125138

126139
@property
127140
def shape(self) -> tf.TensorShape:
@@ -130,15 +143,24 @@ def shape(self) -> tf.TensorShape:
130143
@tf.Module.with_name_scope
131144
def __call__(self) -> tf.Tensor:
132145
"""Computes and returns the convolution kernel as a `tf.Tensor`."""
133-
return tf.reshape(tf.linalg.matmul(self._matrix, self.rdft), self.shape)
146+
rdft = tf.dtypes.complex(self.real, self.imag) * self._norm
147+
if self.shape.rank == 3:
148+
kernel = tf.signal.irfft(rdft, fft_length=self.shape[:-2])
149+
return tf.transpose(kernel, (2, 0, 1))
150+
elif self.shape.rank == 4:
151+
kernel = tf.signal.irfft2d(rdft, fft_length=self.shape[:-2])
152+
return tf.transpose(kernel, (2, 3, 0, 1))
153+
else:
154+
assert self.shape.rank == 5, self.shape
155+
kernel = tf.signal.irfft3d(rdft, fft_length=self.shape[:-2])
156+
return tf.transpose(kernel, (2, 3, 4, 0, 1))
134157

135158
def get_config(self) -> Dict[str, Any]:
136159
config = super().get_config()
137160
config.update(
138161
initial_value=None,
139-
dc=self.dc,
140-
shape=tuple(self.shape),
141-
dtype=self.rdft.dtype.name,
162+
shape=tuple(map(int, self.shape)),
163+
dtype=self.dtype.name,
142164
)
143165
return config
144166

@@ -220,7 +242,7 @@ def get_config(self) -> Dict[str, Any]:
220242
initial_value=None,
221243
minimum=self.minimum,
222244
offset=self.offset,
223-
shape=tuple(self.variable.shape),
245+
shape=tuple(map(int, self.variable.shape)),
224246
dtype=self.variable.dtype.name,
225247
)
226248
return config

tensorflow_compression/python/layers/parameters_test.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ==============================================================================
1515
"""Tests of parameters."""
1616

17+
from absl.testing import parameterized
1718
import tensorflow as tf
1819
from tensorflow_compression.python.layers import parameters
1920

@@ -49,24 +50,33 @@ def test_converts_to_tensor(self):
4950
self.assertEqual(value.dtype.name, converted.dtype.name)
5051

5152

52-
class RDFTParameterTest(ParameterTest, tf.test.TestCase):
53+
class RDFTParameterTest(ParameterTest, tf.test.TestCase,
54+
parameterized.TestCase):
5355

5456
cls = parameters.RDFTParameter
55-
kwargs = dict(name="hello_rdft", dc=True)
57+
kwargs = dict(name="rdft_kernel")
5658
shape = (3, 3, 1, 2)
5759

58-
def test_initial_value_is_reproduced_without_dc(self):
59-
initial_value = tf.random.uniform(self.shape, dtype=tf.float32)
60-
parameter = self.cls(initial_value, dc=False)
61-
expected_value = initial_value - tf.reduce_mean(
62-
initial_value, axis=(0, 1), keepdims=True)
63-
self.assertAllClose(expected_value, parameter(), atol=1e-6, rtol=0)
60+
# TODO(jonycgn): Find out why 3D RFFT gradients are not implemented in TF.
61+
@parameterized.parameters((7, 3, 2), (5, 3, 1, 2))
62+
def test_gradients_propagate(self, *shape):
63+
initial_value = tf.random.uniform(shape, dtype=tf.float32)
64+
parameter = self.cls(initial_value, **self.kwargs)
65+
rand = tf.random.uniform(shape)
66+
with tf.GradientTape() as tape:
67+
loss = tf.reduce_sum(rand * parameter())
68+
gradients = tape.gradient(loss, parameter.variables)
69+
self.assertLen(gradients, 2)
70+
self.assertNotAllClose(
71+
tf.zeros_like(gradients[0]), gradients[0], atol=1e-1, rtol=0)
72+
self.assertNotAllClose(
73+
tf.zeros_like(gradients[1]), gradients[1], atol=1e-1, rtol=0)
6474

6575

6676
class GDNParameterTest(ParameterTest, tf.test.TestCase):
6777

6878
cls = parameters.GDNParameter
69-
kwargs = dict(name="hello_gdn")
79+
kwargs = dict(name="gdn_parameter")
7080
shape = (2, 1, 3)
7181

7282
def test_initial_value_is_reproduced_with_minimum(self):

tensorflow_compression/python/layers/signal_conv_test.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,18 +33,19 @@ def test_invalid_data_format_raises_error(self):
3333
def test_variables_are_enumerated(self):
3434
layer = signal_conv.SignalConv2D(3, 1, use_bias=True)
3535
layer.build((None, None, None, 2))
36-
self.assertLen(layer.weights, 2)
37-
self.assertLen(layer.trainable_weights, 2)
36+
self.assertLen(layer.weights, 3)
37+
self.assertLen(layer.trainable_weights, 3)
3838
weight_names = [w.name for w in layer.weights]
39-
self.assertSameElements(weight_names, ["kernel_rdft:0", "bias:0"])
39+
self.assertSameElements(
40+
weight_names, ["kernel_real:0", "kernel_imag:0", "bias:0"])
4041

4142
def test_bias_variable_is_not_unnecessarily_created(self):
4243
layer = signal_conv.SignalConv2D(5, 3, use_bias=False)
4344
layer.build((None, None, None, 3))
44-
self.assertLen(layer.weights, 1)
45-
self.assertLen(layer.trainable_weights, 1)
45+
self.assertLen(layer.weights, 2)
46+
self.assertLen(layer.trainable_weights, 2)
4647
weight_names = [w.name for w in layer.weights]
47-
self.assertSameElements(weight_names, ["kernel_rdft:0"])
48+
self.assertSameElements(weight_names, ["kernel_real:0", "kernel_imag:0"])
4849

4950
def test_variables_are_not_enumerated_when_overridden(self):
5051
layer = signal_conv.SignalConv2D(1, 1)
@@ -58,7 +59,7 @@ def test_variables_trainable_state_follows_layer(self):
5859
layer = signal_conv.SignalConv2D(1, 1, use_bias=True)
5960
layer.trainable = False
6061
layer.build((None, None, None, 1))
61-
self.assertLen(layer.weights, 2)
62+
self.assertLen(layer.weights, 3)
6263
self.assertEmpty(layer.trainable_weights)
6364

6465
def test_attributes_cannot_be_set_after_build(self):
@@ -107,7 +108,7 @@ def test_variables_receive_gradients(self):
107108
with tf.GradientTape() as g:
108109
y = layer(x)
109110
grads = g.gradient(y, layer.trainable_weights)
110-
self.assertLen(grads, 2)
111+
self.assertLen(grads, 3)
111112
self.assertNotIn(None, grads)
112113
grad_shapes = [tuple(g.shape) for g in grads]
113114
weight_shapes = [tuple(w.shape) for w in layer.trainable_weights]

tensorflow_compression/python/ops/BUILD

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -60,19 +60,6 @@ py_test(
6060
deps = [":soft_round_ops"],
6161
)
6262

63-
py_library(
64-
name = "spectral_ops",
65-
srcs = ["spectral_ops.py"],
66-
srcs_version = "PY3",
67-
)
68-
69-
py_test(
70-
name = "spectral_ops_test",
71-
srcs = ["spectral_ops_test.py"],
72-
python_version = "PY3",
73-
deps = [":spectral_ops"],
74-
)
75-
7663
filegroup(
7764
name = "py_src",
7865
srcs = glob(["*.py"]),

tensorflow_compression/python/ops/spectral_ops.py

Lines changed: 0 additions & 70 deletions
This file was deleted.

tensorflow_compression/python/ops/spectral_ops_test.py

Lines changed: 0 additions & 32 deletions
This file was deleted.

0 commit comments

Comments
 (0)